diff --git a/.github/scripts/eval_chat_config.py b/.github/scripts/eval_chat_config.py
index e2463c0f39..74ae7a8968 100644
--- a/.github/scripts/eval_chat_config.py
+++ b/.github/scripts/eval_chat_config.py
@@ -1,7 +1,7 @@
from copy import deepcopy
from mmengine.config import read_base
-from opencompass.models import TurboMindModel, TurboMindModelwithChatTemplate
+from opencompass.models import TurboMindModelwithChatTemplate
with read_base():
# choose a list of datasets
@@ -84,6 +84,8 @@
models as hf_mistral_chat_7b # noqa: F401, E501
from opencompass.configs.models.mistral.hf_mixtral_8x7b_instruct_v0_1 import \
models as hf_mixtral_chat_8x7b # noqa: F401, E501
+ from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import \
+ models as lmdeploy_qwen2_5_7b_instruct # noqa: F401, E501
from opencompass.configs.models.qwen.hf_qwen1_5_7b_chat import \
models as hf_qwen1_5_chat_7b # noqa: F401, E501
from opencompass.configs.models.qwen.hf_qwen1_5_moe_a2_7b_chat import \
@@ -146,10 +148,8 @@
turbomind_internlm2_5_7b_chat_4bits = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_kvint4 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_kvint8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
-turbomind_internlm2_5_7b_chat_batch1 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
-turbomind_internlm2_5_7b_chat_batch1_4bits = deepcopy(
- *lmdeploy_internlm2_5_7b_chat)
pytorch_internlm2_5_7b_chat = deepcopy(*lmdeploy_internlm2_5_7b_chat)
+pytorch_internlm2_5_7b_chat_w8a8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
# ===== Configs for internlm/internlm2_5_20b_chat =====
turbomind_internlm2_5_20b_chat = deepcopy(*lmdeploy_internlm2_5_20b_chat)
@@ -181,26 +181,6 @@
turbomind_qwen_7b_chat_kvint8 = deepcopy(*lmdeploy_qwen_7b_chat)
pytorch_qwen_7b_chat = deepcopy(*lmdeploy_qwen_7b_chat)
-# ===== Configs for meta-llama/Llama-2-7b-chat-hf =====
-turbomind_llama2_7b_chat = dict(type=TurboMindModel,
- abbr='tb_llama2_chat_7b',
- path='meta-llama/Llama-2-7b-chat-hf',
- engine_config=dict(session_len=MAX_SESSION_LEN,
- max_batch_size=128),
- gen_config=dict(top_k=1,
- top_p=0.8,
- temperature=1.0,
- max_new_tokens=MAX_NEW_TOKENS),
- max_out_len=MAX_NEW_TOKENS,
- max_seq_len=MAX_SESSION_LEN,
- batch_size=128,
- meta_template=llama2_meta_template,
- run_cfg=dict(num_gpus=1),
- end_str='[INST]')
-turbomind_llama2_7b_chat_4bits = deepcopy(turbomind_llama2_7b_chat)
-turbomind_llama2_7b_chat_kvint4 = deepcopy(turbomind_llama2_7b_chat)
-turbomind_llama2_7b_chat_kvint8 = deepcopy(turbomind_llama2_7b_chat)
-
# ===== Configs for meta-llama/Meta-Llama-3-8B-Instruct =====
turbomind_llama3_8b_instruct = deepcopy(*lmdeploy_llama3_8b_instruct)
turbomind_llama3_8b_instruct_4bits = deepcopy(*lmdeploy_llama3_8b_instruct)
@@ -218,6 +198,7 @@
turbomind_llama3_1_8b_instruct_kvint8 = deepcopy(
turbomind_llama3_1_8b_instruct)
pytorch_llama3_1_8b_instruct = deepcopy(turbomind_llama3_1_8b_instruct)
+pytorch_llama3_1_8b_instruct_w8a8 = deepcopy(turbomind_llama3_1_8b_instruct)
# ===== Configs for Qwen/Qwen2-7B-Instruct =====
turbomind_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)
@@ -225,17 +206,36 @@
turbomind_qwen2_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_7b_instruct)
turbomind_qwen2_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_7b_instruct)
pytorch_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)
+pytorch_qwen2_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_7b_instruct)
+
+# ===== Configs for Qwen/Qwen2.5-7B-Instruct =====
+turbomind_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+turbomind_qwen2_5_7b_instruct_4bits = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+turbomind_qwen2_5_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+turbomind_qwen2_5_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+pytorch_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+pytorch_qwen2_5_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+
+# ===== Configs for meta-llama/Llama-2-7b-chat-hf =====
+turbomind_llama2_7b_chat = deepcopy(*lmdeploy_llama2_7b_chat)
+turbomind_llama2_7b_chat_4bits = deepcopy(*lmdeploy_llama2_7b_chat)
+turbomind_llama2_7b_chat_kvint4 = deepcopy(*lmdeploy_llama2_7b_chat)
+turbomind_llama2_7b_chat_kvint8 = deepcopy(*lmdeploy_llama2_7b_chat)
for model in [v for k, v in locals().items() if k.startswith('turbomind_')]:
- model['engine_config']['max_batch_size'] = 128
+ model['engine_config']['max_batch_size'] = 1
model['gen_config']['do_sample'] = False
- model['batch_size'] = 128
+ model['batch_size'] = 100
for model in [v for k, v in locals().items() if k.endswith('_4bits')]:
model['engine_config']['model_format'] = 'awq'
model['abbr'] = model['abbr'] + '_4bits'
model['path'] = model['path'] + '-inner-4bits'
+for model in [v for k, v in locals().items() if k.endswith('_w8a8')]:
+ model['abbr'] = model['abbr'] + '_w8a8'
+ model['path'] = model['path'] + '-inner-w8a8'
+
for model in [v for k, v in locals().items() if k.endswith('_kvint4')]:
model['engine_config']['quant_policy'] = 4
model['abbr'] = model['abbr'] + '_kvint4'
@@ -247,24 +247,19 @@
for model in [v for k, v in locals().items() if k.startswith('pytorch_')]:
model['abbr'] = model['abbr'].replace('turbomind', 'pytorch')
model['backend'] = 'pytorch'
- model['engine_config']['max_batch_size'] = 64
- model['gen_config']['do_sample'] = False
- model['batch_size'] = 64
-
-for model in [v for k, v in locals().items() if '_batch1' in k]:
- model['abbr'] = model['abbr'] + '_batch1'
model['engine_config']['max_batch_size'] = 1
- model['batch_size'] = 1
+ model['gen_config']['do_sample'] = False
+ model['batch_size'] = 100
basic_pytorch_chat_tp1 = dict(type=TurboMindModelwithChatTemplate,
engine_config=dict(session_len=MAX_SESSION_LEN,
- max_batch_size=64,
+ max_batch_size=1,
tp=1),
gen_config=dict(do_sample=False,
max_new_tokens=MAX_NEW_TOKENS),
max_out_len=MAX_NEW_TOKENS,
max_seq_len=MAX_SESSION_LEN,
- batch_size=64,
+ batch_size=100,
run_cfg=dict(num_gpus=1))
# ===== Configs for Qwen/Qwen1.5-MoE-A2.7B-Chat =====
@@ -277,6 +272,13 @@
pytorch_gemma_2_9b_it['abbr'] = 'pytorch_gemma_2_9b_it'
pytorch_gemma_2_9b_it['path'] = 'google/gemma-2-9b-it'
+# ===== Configs for google/gemma-2-27b-it =====
+pytorch_gemma_2_27b_it = deepcopy(basic_pytorch_chat_tp1)
+pytorch_gemma_2_27b_it['abbr'] = 'pytorch_gemma_2_27b_it'
+pytorch_gemma_2_27b_it['path'] = 'google/gemma-2-27b-it'
+pytorch_gemma_2_27b_it['run_cfg']['num_gpus'] = 2
+pytorch_gemma_2_27b_it['engine_config']['tp'] = 2
+
race_datasets = [race_datasets[1]]
# Summarizer
diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index bd3876f9ed..e75f728783 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -88,7 +88,7 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Clone repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v2
if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}
with:
repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}
@@ -105,10 +105,8 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt
- name: Install lmdeploy
if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}
@@ -148,9 +146,15 @@ jobs:
needs: [benchmark]
timeout-minutes: 5
runs-on: [self-hosted, linux-a100]
+ container:
+ image: openmmlab/lmdeploy:latest-cu11
+ options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never"
+ volumes:
+ - /nvme/qa_test_models:/nvme/qa_test_models
+ - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Clone repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v2
with:
repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}
ref: ${{github.event.inputs.repo_ref || 'main'}}
diff --git a/.github/workflows/daily_ete_test.yml b/.github/workflows/daily_ete_test.yml
index dbacfc32f5..d6299e163a 100644
--- a/.github/workflows/daily_ete_test.yml
+++ b/.github/workflows/daily_ete_test.yml
@@ -130,7 +130,7 @@ jobs:
needs: download_pkgs
if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}}
runs-on: [self-hosted, linux-a100]
- timeout-minutes: 120
+ timeout-minutes: 150
env:
PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA
MODELSCOPE_CACHE: /root/modelscope_hub
@@ -149,15 +149,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -166,7 +165,6 @@ jobs:
pip install ${{env.DEEPSEEK_VL}} --no-deps
- name: Check env
run: |
- pip install transformers
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -244,20 +242,20 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
python3 -m pip install lmdeploy-*.whl --no-deps
python3 -m pip install -r requirements/test.txt
+ rm -rf ${{env.DEEPSEEK_VL}}/build
pip install ${{env.DEEPSEEK_VL}} --no-deps
- name: Check env
run: |
@@ -286,6 +284,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - pipeline
continue-on-error: true
if: matrix.function == 'pipeline'
@@ -294,6 +294,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - restful
continue-on-error: true
if: matrix.function == 'restful'
@@ -302,6 +304,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - restful workspace
continue-on-error: true
if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'restful'
@@ -310,6 +314,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/restful/test_restful_chat_workspace.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/restful/test_restful_chat_workspace.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - local testcase
if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'local_case'
run: |
@@ -344,15 +350,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -436,15 +441,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -497,15 +501,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -560,15 +563,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -600,7 +602,7 @@ jobs:
run: |
export LMDEPLOY_DIR=$(pwd)
- python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
+ python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat_w8a8, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct_w8a8, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, turbomind_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct_w8a8, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, pytorch_gemma_2_27b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
- name: Evaluate base models
if: matrix.evaluate_type == 'base'
run: |
@@ -622,11 +624,17 @@ jobs:
needs: [test_benchmark]
timeout-minutes: 5
runs-on: [self-hosted, linux-a100]
+ container:
+ image: openmmlab/lmdeploy:latest-cu11
+ options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never"
+ volumes:
+ - /nvme/qa_test_models:/nvme/qa_test_models
+ - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
env:
BENCHMARK_REPORT_DIR: /nvme/qa_test_models/benchmark-reports/${{ github.run_id }}
steps:
- name: Clone repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v2
with:
repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}
ref: ${{github.event.inputs.repo_ref || 'main'}}
diff --git a/.github/workflows/daily_ete_test_v100.yml b/.github/workflows/daily_ete_test_v100.yml
index 8a662b85f5..343cfdea50 100644
--- a/.github/workflows/daily_ete_test_v100.yml
+++ b/.github/workflows/daily_ete_test_v100.yml
@@ -158,8 +158,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -167,7 +166,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -245,8 +243,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -254,7 +251,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -345,8 +341,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -354,7 +349,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -437,8 +431,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -446,7 +439,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -498,8 +490,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -507,7 +498,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -560,8 +550,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -575,7 +564,6 @@ jobs:
echo "OPENCOMPASS_DIR=$(pwd)" >> $GITHUB_ENV
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -593,13 +581,13 @@ jobs:
run: |
export LMDEPLOY_DIR=$(pwd)
- python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
+ python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
- name: Evaluate base models
if: matrix.evaluate_type == 'base'
run: |
export LMDEPLOY_DIR=$(pwd)
- python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_5_7b, turbomind_qwen2_5_14b, turbomind_internlm2_5_7b_batch1]" "[*race_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]" /root/evaluation-reports/${{ github.run_id }} base true
+ python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_5_7b, turbomind_qwen2_5_14b]" "[*race_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]" /root/evaluation-reports/${{ github.run_id }} base true
- name: Clear workspace
if: always()
run: |
diff --git a/.github/workflows/evaluate.yml b/.github/workflows/evaluate.yml
index b6ab89f595..dbfff04fe2 100644
--- a/.github/workflows/evaluate.yml
+++ b/.github/workflows/evaluate.yml
@@ -17,7 +17,7 @@ on:
required: true
description: 'Tested TurboMind models list. eg. [internlm_chat_7b,internlm_chat_7b_w8a16]'
type: string
- default: '[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, turbomind_internlm2_chat_7b_4bits, turbomind_internlm2_chat_7b_kvint4, turbomind_internlm2_chat_7b_kvint8, turbomind_internlm2_5_7b_chat_4bits, turbomind_internlm2_5_7b_chat_kvint4, turbomind_internlm2_5_7b_chat_kvint8, turbomind_internlm2_5_20b_chat_4bits, turbomind_internlm2_5_20b_chat_kvint4, turbomind_internlm2_5_20b_chat_kvint8, turbomind_qwen1_5_7b_chat_4bits, turbomind_qwen1_5_7b_chat_kvint4, turbomind_qwen1_5_7b_chat_kvint8, turbomind_llama2_7b_chat_4bits, turbomind_llama2_7b_chat_kvint4, turbomind_llama2_7b_chat_kvint8, turbomind_llama3_8b_instruct_4bits, turbomind_llama3_8b_instruct_kvint4, turbomind_llama3_8b_instruct_kvint8, turbomind_llama3_1_8b_instruct_4bits, turbomind_llama3_1_8b_instruct_kvint4, turbomind_llama3_1_8b_instruct_kvint8, turbomind_qwen2_7b_instruct_4bits, turbomind_qwen2_7b_instruct_kvint8]'
+ default: '[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, turbomind_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, pytorch_gemma_2_27b_it, turbomind_internlm2_chat_7b_kvint4, turbomind_internlm2_chat_7b_kvint8, turbomind_internlm2_5_7b_chat_4bits, turbomind_internlm2_5_7b_chat_kvint4, turbomind_internlm2_5_7b_chat_kvint8, pytorch_internlm2_5_7b_chat_w8a8, turbomind_internlm2_5_20b_chat_4bits, turbomind_internlm2_5_20b_chat_kvint4, turbomind_internlm2_5_20b_chat_kvint8, turbomind_qwen1_5_7b_chat_4bits, turbomind_qwen1_5_7b_chat_kvint4, turbomind_qwen1_5_7b_chat_kvint8, turbomind_llama2_7b_chat_4bits, turbomind_llama2_7b_chat_kvint4, turbomind_llama2_7b_chat_kvint8, turbomind_llama3_8b_instruct_4bits, turbomind_llama3_8b_instruct_kvint4, turbomind_llama3_8b_instruct_kvint8, turbomind_llama3_1_8b_instruct_4bits, turbomind_llama3_1_8b_instruct_kvint4, turbomind_llama3_1_8b_instruct_kvint8, pytorch_llama3_1_8b_instruct_w8a8, turbomind_qwen2_7b_instruct_4bits, turbomind_qwen2_7b_instruct_kvint8, turbomind_qwen2_5_7b_instruct_4bits, turbomind_qwen2_5_7b_instruct_kvint8, pytorch_qwen2_5_7b_instruct_w8a8]'
chat_datasets:
required: true
description: 'Tested datasets list. eg. [*bbh_datasets,*ceval_datasets,*cmmlu_datasets,*GaokaoBench_datasets,*gpqa_datasets,*gsm8k_datasets,*hellaswag_datasets,*humaneval_datasets,*ifeval_datasets,*math_datasets,*sanitized_mbpp_datasets,*mmlu_datasets,*nq_datasets,*race_datasets,*TheoremQA_datasets,*triviaqa_datasets,*winogrande_datasets,*crowspairs_datasets]'
@@ -25,7 +25,7 @@ on:
default: '[*mmlu_datasets, *gsm8k_datasets, *ifeval_datasets]'
base_models:
required: true
- description: 'Tested TurboMind models list. eg. [turbomind_internlm2_5_7b, turbomind_qwen2_7b, turbomind_internlm2_5_7b_batch1]'
+ description: 'Tested TurboMind models list. eg. [turbomind_internlm2_5_7b, turbomind_qwen2_7b]'
type: string
default: '[turbomind_internlm2_5_7b, turbomind_internlm2_5_7b_4bits, turbomind_internlm2_5_7b_batch1, turbomind_internlm2_5_7b_batch1_4bits, turbomind_qwen2_7b, turbomind_qwen2_5_7b, turbomind_qwen2_5_14b]'
baes_datasets:
@@ -133,8 +133,10 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install -e /root/packages/AutoAWQ_kernels
+ python3 -m pip install /root/packages/autoawq-*.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r /root/models/offline_pkg/requirements.txt
- name: Install lmdeploy
if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}
diff --git a/.github/workflows/pr_ete_test.yml b/.github/workflows/pr_ete_test.yml
index 3a19ebe870..2d1c4b63f5 100644
--- a/.github/workflows/pr_ete_test.yml
+++ b/.github/workflows/pr_ete_test.yml
@@ -10,7 +10,7 @@ on:
- "3rdparty/**"
- "lmdeploy/**"
- "requirements/**"
- - "requirements.txt"
+ - "requirements_cuda.txt"
- "CMakeLists.txt"
- "setup.py"
workflow_dispatch:
@@ -68,7 +68,7 @@ jobs:
export PATH=$PATH:/usr/local/openmpi/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/openmpi/lib
python3 -m pip install cmake packaging wheel transformers_stream_generator transformers datasets openai einops timm decord
- python3 -m pip install -r requirements.txt -r requirements/test.txt -r requirements/build.txt
+ python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt -r requirements/build.txt
mkdir -p build && cd build &&\
sh ../generate.sh &&\
ninja -j$(nproc) && ninja install &&\
diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index ec6db0682d..3a459050ec 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -10,7 +10,7 @@ on:
- "3rdparty/**"
- "lmdeploy/**"
- "requirements/**"
- - "requirements.txt"
+ - "requirements_cuda.txt"
- "CMakeLists.txt"
- "setup.py"
push:
@@ -24,7 +24,7 @@ on:
- "3rdparty/**"
- "lmdeploy/**"
- "requirements/**"
- - "requirements.txt"
+ - "requirements_cuda.txt"
- "CMakeLists.txt"
- "setup.py"
tags:
@@ -39,6 +39,7 @@ jobs:
options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e CUDA_VISIBLE_DEVICES=2,3 --pull never"
volumes:
- /nvme/share_data/github-actions/pip-cache:/root/.cache/pip
+ - /nvme/share_data/github-actions/hf_home:/root/.cache/huggingface
- /nvme/share_data/github-actions/packages:/root/packages
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
@@ -78,7 +79,7 @@ jobs:
python3 -m pip install pynvml packaging protobuf transformers_stream_generator
# manually install flash attn
python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
- python3 -m pip install -r requirements.txt -r requirements/test.txt
+ python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt
python3 -m pip install .
- name: Check env
run: |
diff --git a/README.md b/README.md
index d160338aa6..8ef7b7994f 100644
--- a/README.md
+++ b/README.md
@@ -125,6 +125,8 @@ For detailed inference benchmarks in more devices and more settings, please refe
Qwen1.5 (0.5B - 110B)
Qwen1.5 - MoE (0.5B - 72B)
Qwen2 (0.5B - 72B)
+ Qwen2-MoE (57BA14B)
+ Qwen2.5 (0.5B - 32B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
@@ -136,6 +138,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
Mistral (7B)
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
+ DeepSeek-V2.5 (236B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
Dbrx (132B)
diff --git a/README_ja.md b/README_ja.md
index fda176229e..77badaac36 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -122,6 +122,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
Qwen1.5 (0.5B - 110B)
Qwen1.5 - MoE (0.5B - 72B)
Qwen2 (0.5B - 72B)
+ Qwen2-MoE (57BA14B)
+ Qwen2.5 (0.5B - 32B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
@@ -133,6 +135,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
Mistral (7B)
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
+ DeepSeek-V2.5 (236B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
Dbrx (132B)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 6c24b2e500..9f3cd40a64 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -126,6 +126,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
Qwen1.5 (0.5B - 110B)
Qwen1.5 - MoE (0.5B - 72B)
Qwen2 (0.5B - 72B)
+ Qwen2-MoE (57BA14B)
+ Qwen2.5 (0.5B - 32B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
@@ -137,6 +139,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
Mistral (7B)
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
+ DeepSeek-V2.5 (236B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
Dbrx (132B)
diff --git a/autotest/config.yaml b/autotest/config.yaml
index b4fd4e1712..d92e32a595 100644
--- a/autotest/config.yaml
+++ b/autotest/config.yaml
@@ -17,15 +17,21 @@ tp_config:
Meta-Llama-3-1-70B-Instruct: 4
internlm2_5-7b-chat-1m: 4
Qwen2-7B-Instruct-GPTQ-Int4: 2
- InternVL2-40B: 2
+ InternVL2-26B: 2
+ InternVL2-40B: 4
+ InternVL2_5-26B: 2
+ InternVL2_5-38B: 4
MiniCPM-V-2_6: 2
Qwen2.5-72B-Instruct: 4
+ gemma-2-27b-it: 2
+ DeepSeek-V2-Lite-Chat: 2
turbomind_chat_model:
- meta-llama/Llama-3.2-1B-Instruct
- meta-llama/Llama-3.2-3B-Instruct
- meta-llama/Meta-Llama-3-1-8B-Instruct
- meta-llama/Meta-Llama-3-1-8B-Instruct-AWQ
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
- meta-llama/Meta-Llama-3-8B-Instruct
- meta-llama/Llama-2-7b-chat-hf
- internlm/internlm2_5-7b-chat
@@ -35,6 +41,10 @@ turbomind_chat_model:
- internlm/internlm-chat-20b
- internlm/internlm-xcomposer2-4khd-7b
- internlm/internlm-xcomposer2d5-7b
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-8B
@@ -42,6 +52,7 @@ turbomind_chat_model:
- OpenGVLab/InternVL2-40B
- OpenGVLab/InternVL-Chat-V1-5
- OpenGVLab/Mini-InternVL-Chat-2B-V1-5
+ - OpenGVLab/InternVL2-Llama3-76B-AWQ
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-7B-Instruct-AWQ
- Qwen/Qwen2-1.5B-Instruct
@@ -51,6 +62,7 @@ turbomind_chat_model:
- Qwen/Qwen-VL-Chat
- Qwen/Qwen2.5-0.5B-Instruct
- Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2-7B-Instruct-GPTQ-Int4
- Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4
- mistralai/Mistral-7B-Instruct-v0.3
@@ -69,10 +81,12 @@ turbomind_chat_model:
- THUDM/glm-4-9b-chat
- openbmb/MiniCPM-Llama3-V-2_5
- openbmb/MiniCPM-V-2_6
+ - allenai/Molmo-7B-D-0924
pytorch_chat_model:
- meta-llama/Meta-Llama-3-8B-Instruct
- meta-llama/Meta-Llama-3-1-8B-Instruct
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
- meta-llama/Llama-3.2-1B-Instruct
- meta-llama/Llama-3.2-3B-Instruct
- meta-llama/Llama-3.2-11B-Vision-Instruct
@@ -81,6 +95,10 @@ pytorch_chat_model:
- internlm/internlm2_5-20b-chat
- internlm/internlm2-chat-20b
- internlm/internlm-chat-20b
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-4B
@@ -92,10 +110,11 @@ pytorch_chat_model:
- baichuan-inc/Baichuan2-7B-Chat
- baichuan-inc/Baichuan2-13B-Chat
- 01-ai/Yi-6B-Chat
- - liuhaotian/llava-v1.5-13b
- - liuhaotian/llava-v1.6-vicuna-7b
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-1.5B-Instruct
+ - Qwen/Qwen2.5-0.5B-Instruct
+ - Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen1.5-7B-Chat
- Qwen/Qwen1.5-MoE-A2.7B-Chat
- Qwen/Qwen2-VL-2B-Instruct
@@ -104,6 +123,7 @@ pytorch_chat_model:
- mistralai/Mixtral-8x7B-Instruct-v0.1
- google/gemma-7b-it
- google/gemma-2-9b-it
+ - google/gemma-2-27b-it
- deepseek-ai/deepseek-moe-16b-chat
- deepseek-ai/deepseek-coder-1.3b-instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
@@ -111,6 +131,7 @@ pytorch_chat_model:
- THUDM/cogvlm2-llama3-chinese-chat-19B
- THUDM/glm-4v-9b
- THUDM/glm-4-9b-chat
+ - openbmb/MiniCPM-V-2_6
- microsoft/Phi-3-mini-4k-instruct
- microsoft/Phi-3-vision-128k-instruct
@@ -122,11 +143,16 @@ turbomind_vl_model:
- deepseek-ai/deepseek-vl-1.3b-chat
- OpenGVLab/InternVL-Chat-V1-5
- OpenGVLab/Mini-InternVL-Chat-2B-V1-5
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-8B
- OpenGVLab/InternVL2-26B
- OpenGVLab/InternVL2-40B
+ - OpenGVLab/InternVL2-Llama3-76B-AWQ
- internlm/internlm-xcomposer2d5-7b
- internlm/internlm-xcomposer2-4khd-7b
- openbmb/MiniCPM-Llama3-V-2_5
@@ -136,6 +162,10 @@ pytorch_vl_model:
- meta-llama/Llama-3.2-11B-Vision-Instruct
- OpenGVLab/InternVL-Chat-V1-5
- OpenGVLab/Mini-InternVL-Chat-2B-V1-5
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-4B
@@ -148,6 +178,7 @@ pytorch_vl_model:
- THUDM/cogvlm-chat-hf
- THUDM/cogvlm2-llama3-chinese-chat-19B
- THUDM/glm-4v-9b
+ - openbmb/MiniCPM-V-2_6
- microsoft/Phi-3-vision-128k-instruct
- microsoft/Phi-3.5-vision-instruct
@@ -166,6 +197,8 @@ pytorch_base_model:
turbomind_quatization:
no_awq:
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen1.5-MoE-A2.7B-Chat
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
@@ -174,18 +207,28 @@ turbomind_quatization:
- deepseek-ai/deepseek-coder-1.3b-instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
- codellama/CodeLlama-7b-Instruct-hf
+ - allenai/Molmo-7B-D-0924
gptq:
- internlm/internlm2_5-7b-chat
no_kvint4:
+ - meta-llama/Llama-3.2-1B-Instruct
+ - OpenGVLab/InternVL2-1B
+ - OpenGVLab/InternVL2_5-1B
- openbmb/MiniCPM-V-2_6
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-7B-Instruct-AWQ
- Qwen/Qwen2-1.5B-Instruct
- Qwen/Qwen2.5-0.5B-Instruct
- Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2-7B-Instruct-GPTQ-Int4
+ - allenai/Molmo-7B-D-0924
no_kvint8:
+ - deepseek-ai/DeepSeek-V2-Chat
+ no_converted:
- deepseek-ai/DeepSeek-V2-Lite-Chat
+ - Qwen/Qwen2.5-72B-Instruct
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
pytorch_quatization:
awq:
@@ -200,23 +243,39 @@ pytorch_quatization:
- Qwen/Qwen1.5-7B-Chat
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-1.5B-Instruct
- - microsoft/Phi-3-mini-4k-instruct
+ - Qwen/Qwen2.5-7B-Instruct
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
+ - microsoft/Phi-3-mini-4k-instruct
w8a8:
- meta-llama/Meta-Llama-3-8B-Instruct
+ - meta-llama/Llama-3.2-1B-Instruct
- meta-llama/Llama-2-7b-chat-hf
- internlm/internlm2-chat-20b
- internlm/internlm2_5-7b-chat
- internlm/internlm2_5-20b-chat
- 01-ai/Yi-6B-Chat
+ - mistralai/Mistral-7B-Instruct-v0.3
+ - Qwen/Qwen1.5-7B-Chat
+ - Qwen/Qwen2-7B-Instruct
+ - Qwen/Qwen2-1.5B-Instruct
+ - Qwen/Qwen2.5-7B-Instruct
+ - microsoft/Phi-3-mini-4k-instruct
- internlm/internlm2_5-20b
- internlm/internlm2_5-7b
+ - meta-llama/Meta-Llama-3-1-8B-Instruct
no_kvint4:
+ - meta-llama/Llama-3.2-1B-Instruct
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-4B
+ - OpenGVLab/InternVL2_5-1B
- Qwen/Qwen2-7B-Instruct
+ - Qwen/Qwen2-7B-Instruct-AWQ
- Qwen/Qwen2-1.5B-Instruct
+ - Qwen/Qwen2.5-0.5B-Instruct
+ - Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
+ - Qwen/Qwen2-7B-Instruct-GPTQ-Int4
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
@@ -247,3 +306,4 @@ benchmark_model:
- mistralai/Mistral-7B-Instruct-v0.3
- mistralai/Mixtral-8x7B-Instruct-v0.1
- deepseek-ai/DeepSeek-V2-Lite-Chat
+ - allenai/Molmo-7B-D-0924
diff --git a/autotest/prompt_case.yaml b/autotest/prompt_case.yaml
index 9a5ed45724..468f3e49d6 100644
--- a/autotest/prompt_case.yaml
+++ b/autotest/prompt_case.yaml
@@ -54,6 +54,7 @@ emoji_case:
- 好
- '!'
- u1f44d
+ - 🌟
traditional_chinese_case:
- 介紹澳門景點,使用繁體:
- contain:
diff --git a/autotest/pytest.ini b/autotest/pytest.ini
index 4c963d5bbd..69dc47fa58 100644
--- a/autotest/pytest.ini
+++ b/autotest/pytest.ini
@@ -5,4 +5,4 @@ python_functions = test_* # test function
pytest_runtest_call.tryfirst = True
filterwarnings = ignore::UserWarning
reruns = 2
-reruns_delay = 10
+reruns_delay = 1
diff --git a/autotest/tools/chat/test_command_chat_hf_pytorch.py b/autotest/tools/chat/test_command_chat_hf_pytorch.py
index 1ae3be338b..e6986ec614 100644
--- a/autotest/tools/chat/test_command_chat_hf_pytorch.py
+++ b/autotest/tools/chat/test_command_chat_hf_pytorch.py
@@ -51,6 +51,27 @@ def test_hf_pytorch_chat_tp2(config, model, cli_case_config, worker_id):
assert result, msg
+@pytest.mark.order(10)
+@pytest.mark.usefixtures('cli_case_config')
+@pytest.mark.hf_pytorch_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model', get_torch_model_list(tp_num=4))
+def test_hf_pytorch_chat_tp4(config, model, cli_case_config, worker_id):
+ usercase = 'chat_testcase'
+ result, chat_log, msg = hf_command_line_test(
+ config,
+ usercase,
+ cli_case_config.get(usercase),
+ model,
+ 'pytorch',
+ cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=4))
+ if chat_log is not None:
+ allure.attach.file(chat_log,
+ attachment_type=allure.attachment_type.TEXT)
+
+ assert result, msg
+
+
@pytest.mark.order(10)
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.hf_turbomind_chat
diff --git a/autotest/tools/chat/test_command_chat_hf_turbomind.py b/autotest/tools/chat/test_command_chat_hf_turbomind.py
index 2f13898fec..935a21ee86 100644
--- a/autotest/tools/chat/test_command_chat_hf_turbomind.py
+++ b/autotest/tools/chat/test_command_chat_hf_turbomind.py
@@ -53,6 +53,28 @@ def test_hf_turbomind_chat_tp2(config, model, cli_case_config, worker_id):
assert result, msg
+@pytest.mark.order(10)
+@pytest.mark.usefixtures('cli_case_config')
+@pytest.mark.hf_turbomind_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=4))
+def test_hf_turbomind_chat_tp4(config, model, cli_case_config, worker_id):
+ usercase = 'chat_testcase'
+ result, chat_log, msg = hf_command_line_test(
+ config,
+ usercase,
+ cli_case_config.get(usercase),
+ model,
+ 'turbomind',
+ cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=4))
+
+ if chat_log is not None:
+ allure.attach.file(chat_log,
+ attachment_type=allure.attachment_type.TEXT)
+
+ assert result, msg
+
+
@pytest.mark.order(10)
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.hf_turbomind_chat
diff --git a/autotest/tools/chat/test_command_chat_workspace.py b/autotest/tools/chat/test_command_chat_workspace.py
index a16d4e32f6..415a1c528c 100644
--- a/autotest/tools/chat/test_command_chat_workspace.py
+++ b/autotest/tools/chat/test_command_chat_workspace.py
@@ -9,7 +9,8 @@
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.command_chat
@pytest.mark.gpu_num_1
-@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=1))
+@pytest.mark.parametrize('model',
+ get_turbomind_model_list(tp_num=1, is_converted=True))
def test_workspace_chat_tp1(config, cli_case_config, model, worker_id):
usercase = 'chat_testcase'
# cannot convert with rop-scale params, so the case should be skipped
@@ -32,7 +33,8 @@ def test_workspace_chat_tp1(config, cli_case_config, model, worker_id):
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.command_chat
@pytest.mark.gpu_num_2
-@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=2))
+@pytest.mark.parametrize('model',
+ get_turbomind_model_list(tp_num=2, is_converted=True))
def test_workspace_chat_tp2(config, cli_case_config, model, worker_id):
usercase = 'chat_testcase'
result, chat_log, msg = command_line_test(
@@ -54,7 +56,8 @@ def test_workspace_chat_tp2(config, cli_case_config, model, worker_id):
@pytest.mark.gpu_num_1
@pytest.mark.parametrize('model',
get_turbomind_model_list(tp_num=1,
- model_type='base_model'))
+ model_type='base_model',
+ is_converted=True))
def test_workspace_base_tp1(config, cli_case_config, model, worker_id):
usercase = 'base_testcase'
result, chat_log, msg = command_line_test(
@@ -76,7 +79,8 @@ def test_workspace_base_tp1(config, cli_case_config, model, worker_id):
@pytest.mark.gpu_num_2
@pytest.mark.parametrize('model',
get_turbomind_model_list(tp_num=2,
- model_type='base_model'))
+ model_type='base_model',
+ is_converted=True))
def test_workspace_base_tp2(config, cli_case_config, model, worker_id):
usercase = 'base_testcase'
result, chat_log, msg = command_line_test(
diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py
index 58674fa173..c0348ec500 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py
@@ -56,6 +56,32 @@ def test_pipeline_chat_pytorch_tp2(config, common_case_config, model,
worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat_pytorch
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4, exclude_dup=True))
+def test_pipeline_chat_pytorch_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'pytorch', worker_id))
+ p.start()
+ p.join()
+
+ # assert script
+ assert_pipeline_chat_log(config, common_case_config, model, 'pytorch',
+ worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -109,6 +135,34 @@ def test_pipeline_chat_kvint4_tp2(config, common_case_config, model,
'pytorch-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=4,
+ exclude_dup=True))
+def test_pipeline_chat_kvint4_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'pytorch-kvint', worker_id, {
+ 'quant_policy': 4
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'pytorch-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -162,6 +216,34 @@ def test_pipeline_chat_kvint8_tp2(config, common_case_config, model,
'pytorch-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=8,
+ exclude_dup=True))
+def test_pipeline_chat_kvint8_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'pytorch-kvint', worker_id, {
+ 'quant_policy': 8
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'pytorch-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat_pytorch
diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py
index 8403ced94f..8735b8e937 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py
@@ -34,6 +34,27 @@ def test_pipeline_chat_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4, model_type='vl_model'))
+def test_pipeline_chat_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id))
@@ -71,6 +92,29 @@ def test_pipeline_chat_kvint4_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 4))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=4,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint4_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 4))
@@ -108,6 +152,29 @@ def test_pipeline_chat_kvint8_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 8))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=8,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint8_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 8))
diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py
index d1865175cf..58eab0de76 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py
@@ -48,6 +48,28 @@ def test_pipeline_chat_tp2(config, common_case_config, model, worker_id):
worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model', get_all_model_list(tp_num=4))
+def test_pipeline_chat_tp4(config, common_case_config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'turbomind', worker_id))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model, 'turbomind',
+ worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -95,6 +117,31 @@ def test_pipeline_chat_kvint4_tp2(config, common_case_config, model,
'turbomind-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model', get_all_model_list(tp_num=4, quant_policy=4))
+def test_pipeline_chat_kvint4_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'turbomind-kvint', worker_id, {
+ 'quant_policy': 4
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'turbomind-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -142,6 +189,31 @@ def test_pipeline_chat_kvint8_tp2(config, common_case_config, model,
'turbomind-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model', get_all_model_list(tp_num=4, quant_policy=8))
+def test_pipeline_chat_kvint8_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'turbomind-kvint', worker_id, {
+ 'quant_policy': 8
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'turbomind-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py
index 8c845fa77a..c62bfc5e8e 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py
@@ -34,6 +34,27 @@ def test_pipeline_chat_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_all_model_list(tp_num=4, model_type='vl_model'))
+def test_pipeline_chat_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id))
@@ -71,6 +92,29 @@ def test_pipeline_chat_kvint4_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 4))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_all_model_list(tp_num=4,
+ quant_policy=4,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint4_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 4))
@@ -108,6 +152,29 @@ def test_pipeline_chat_kvint8_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 8))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_all_model_list(tp_num=4,
+ quant_policy=8,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint8_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 8))
diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py
index fc95e288ca..bc0ea3996a 100644
--- a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py
@@ -60,6 +60,23 @@ def test_restful_chat_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api_pytorch
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy):
return [{
'model': item,
@@ -104,6 +121,23 @@ def test_restful_chat_kvint4_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
@@ -138,6 +172,23 @@ def test_restful_chat_kvint8_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py
index bf20c45e6e..cc85d35d09 100644
--- a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py
@@ -53,6 +53,19 @@ def test_restful_chat_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy: int = None):
return [{
'model': item,
@@ -89,6 +102,19 @@ def test_restful_chat_kvint4_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.restful_api_vl
@pytest.mark.gpu_num_1
@@ -113,3 +139,16 @@ def test_restful_chat_kvint8_tp2(config, worker_id):
run_vl_testcase(config)
else:
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py
index 1c9131b32e..435ffc4ae3 100644
--- a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py
@@ -60,6 +60,23 @@ def test_restful_chat_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy):
return [{
'model': item,
@@ -103,6 +120,23 @@ def test_restful_chat_kvint4_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
@@ -137,6 +171,23 @@ def test_restful_chat_kvint8_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py
index 641f2f760f..bbb8718366 100644
--- a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py
@@ -53,6 +53,19 @@ def test_restful_chat_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy: int = None):
return [{
'model': item,
@@ -89,6 +102,19 @@ def test_restful_chat_kvint4_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.restful_api_vl
@pytest.mark.gpu_num_1
@@ -113,3 +139,16 @@ def test_restful_chat_kvint8_tp2(config, worker_id):
run_vl_testcase(config)
else:
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
diff --git a/autotest/tools/restful/test_restful_chat_workspace.py b/autotest/tools/restful/test_restful_chat_workspace.py
index 798a43d7b0..cf69007cca 100644
--- a/autotest/tools/restful/test_restful_chat_workspace.py
+++ b/autotest/tools/restful/test_restful_chat_workspace.py
@@ -23,8 +23,7 @@ def getModelList(tp_num):
'model': item,
'cuda_prefix': None,
'tp_num': tp_num
- } for item in get_turbomind_model_list(tp_num)
- if item not in ('deepseek-ai/deepseek-coder-1.3b-instruct')]
+ } for item in get_turbomind_model_list(tp_num, is_converted=True)]
@pytest.mark.order(7)
diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py
index 24b4a3f8cd..87d5d73f10 100644
--- a/autotest/utils/config_utils.py
+++ b/autotest/utils/config_utils.py
@@ -9,17 +9,33 @@
def get_turbomind_model_list(tp_num: int = None,
model_type: str = 'chat_model',
- quant_policy: int = None):
+ quant_policy: int = None,
+ is_converted: bool = False):
config = get_config()
if quant_policy is None:
- case_list = copy.deepcopy(config.get('turbomind_' + model_type))
+ if is_converted:
+ case_list = [
+ x for x in copy.deepcopy(config.get('turbomind_' + model_type))
+ if x not in config.get('turbomind_quatization').get(
+ 'no_converted')
+ ]
+ else:
+ case_list = copy.deepcopy(config.get('turbomind_' + model_type))
else:
- case_list = [
- x for x in config.get('turbomind_' + model_type)
- if x not in config.get('turbomind_quatization').get(
- 'no_kvint' + str(quant_policy))
- ]
+ if is_converted:
+ case_list = [
+ x for x in config.get('turbomind_' + model_type)
+ if x not in config.get('turbomind_quatization').get(
+                    'no_kvint' + str(quant_policy)) and x not in config.get(
+                        'turbomind_quatization').get('no_converted')
+ ]
+ else:
+ case_list = [
+ x for x in config.get('turbomind_' + model_type)
+ if x not in config.get('turbomind_quatization').get(
+ 'no_kvint' + str(quant_policy))
+ ]
quatization_case_config = config.get('turbomind_quatization')
for key in config.get('turbomind_' + model_type):
@@ -97,7 +113,7 @@ def get_all_model_list(tp_num: int = None,
model_type=model_type):
if case not in case_list:
case_list.append(case)
- return [x for x in case_list if 'w8a8' not in x]
+ return case_list
def get_quantization_model_list(type):
@@ -202,6 +218,7 @@ def get_benchmark_model_list(tp_num,
else:
case_list_base = config.get('benchmark_model')
quatization_case_config = config.get('turbomind_quatization')
+ pytorch_quatization_case_config = config.get('pytorch_quatization')
case_list = copy.deepcopy(case_list_base)
for key in case_list_base:
@@ -210,6 +227,12 @@ def get_benchmark_model_list(tp_num,
'no_awq') and not is_quantization_model(key):
case_list.append(key + '-inner-4bits')
+ for key in case_list_base:
+ if key in config.get('pytorch_chat_model'
+ ) and key in pytorch_quatization_case_config.get(
+ 'w8a8') and not is_quantization_model(key):
+ case_list.append(key + '-inner-w8a8')
+
model_list = [
item for item in case_list if get_tp_num(config, item) == tp_num
]
@@ -228,15 +251,18 @@ def get_benchmark_model_list(tp_num,
'backend': 'pytorch',
'tp_num': tp_num
} for item in model_list if '4bits' not in item and (
- item in config.get('pytorch_chat_model') or tp_num == 4)]
+ item.replace('-inner-w8a8', '') in config.get('pytorch_chat_model')
+ or tp_num == 4)]
for kvint in kvint_list:
result += [{
'model': item,
'backend': 'turbomind',
'quant_policy': kvint,
'tp_num': tp_num
- } for item in model_list if item.replace('-inner-4bits', '') in
- config.get('turbomind_chat_model')]
+ } for item in model_list if item.replace(
+ '-inner-4bits', '') in config.get('turbomind_chat_model')
+ and item.replace('-inner-4bits', '') not in
+ quatization_case_config.get('no_kvint' + str(kvint))]
return result
diff --git a/autotest/utils/pipeline_chat.py b/autotest/utils/pipeline_chat.py
index 023e4ac142..8f03e4e406 100644
--- a/autotest/utils/pipeline_chat.py
+++ b/autotest/utils/pipeline_chat.py
@@ -277,14 +277,14 @@ def assert_pipeline_single_element(output,
return result
-PIC1 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' # noqa E501
-PIC2 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg' # noqa E501
-PIC_BEIJING = 'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg' # noqa E501
-PIC_CHONGQING = 'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg' # noqa E501
-PIC_REDPANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg' # noqa E501
-PIC_PANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg' # noqa E501
-DESC = 'What are the similarities and differences between these two images.' # noqa E501
-DESC_ZH = '两张图有什么相同和不同的地方.' # noqa E501
+PIC1 = 'tiger.jpeg'
+PIC2 = 'human-pose.jpg'
+PIC_BEIJING = 'Beijing_Small.jpeg'
+PIC_CHONGQING = 'Chongqing_Small.jpeg'
+PIC_REDPANDA = 'redpanda.jpg'
+PIC_PANDA = 'panda.jpg'
+DESC = 'What are the similarities and differences between these two images.'
+DESC_ZH = '两张图有什么相同和不同的地方.'
def run_pipeline_vl_chat_test(config,
@@ -296,6 +296,7 @@ def run_pipeline_vl_chat_test(config,
tp = get_tp_num(config, model_case)
model_path = config.get('model_path')
hf_path = model_path + '/' + model_case
+ resource_path = config.get('resource_path')
if 'pytorch' in backend:
backend_config = PytorchEngineConfig(tp=tp, session_len=8192)
@@ -320,7 +321,7 @@ def run_pipeline_vl_chat_test(config,
'pipeline_vl_chat_' + model_case.split('/')[1] + worker_id + '.log')
file = open(pipeline_chat_log, 'w')
- image = load_image(PIC1)
+ image = load_image(f'{resource_path}/{PIC1}')
if 'deepseek' in model_case:
prompt = f'describe this image{IMAGE_TOKEN}'
@@ -352,7 +353,7 @@ def run_pipeline_vl_chat_test(config,
}, {
'type': 'image_url',
'image_url': {
- 'url': PIC1
+ 'url': f'{resource_path}/{PIC1}'
}
}]
}]
@@ -362,7 +363,7 @@ def run_pipeline_vl_chat_test(config,
', reason: OpenAI format example: tiger not in ' +
response.text + '\n')
- image_urls = [PIC2, PIC1]
+ image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}']
images = [load_image(img_url) for img_url in image_urls]
response = pipe((prompt, images))
result = 'tiger' in response.text.lower() or 'ski' in response.text.lower(
@@ -371,7 +372,7 @@ def run_pipeline_vl_chat_test(config,
', reason: Multi-images example: tiger or ski not in ' +
response.text + '\n')
- image_urls = [PIC2, PIC1]
+ image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}']
prompts = [(prompt, load_image(img_url)) for img_url in image_urls]
response = pipe(prompts)
result = ('ski' in response[0].text.lower()
@@ -382,7 +383,7 @@ def run_pipeline_vl_chat_test(config,
', reason: Batch example: ski or tiger not in ' +
str(response) + '\n')
- image = load_image(PIC2)
+ image = load_image(f'{resource_path}/{PIC2}')
sess = pipe.chat((prompt, image))
result = 'ski' in sess.response.text.lower(
) or '滑雪' in sess.response.text.lower()
@@ -397,12 +398,12 @@ def run_pipeline_vl_chat_test(config,
sess.response.text + '\n')
if 'internvl' in model_case.lower():
- internvl_vl_testcase(config, pipe, file)
- internvl_vl_testcase(config, pipe, file, 'cn')
+ internvl_vl_testcase(pipe, file, resource_path)
+ internvl_vl_testcase(pipe, file, resource_path, 'cn')
if 'minicpm' in model_case.lower():
- MiniCPM_vl_testcase(config, pipe, file)
+ MiniCPM_vl_testcase(pipe, file, resource_path)
if 'qwen' in model_case.lower():
- Qwen_vl_testcase(config, pipe, file)
+ Qwen_vl_testcase(pipe, file, resource_path)
file.close()
@@ -410,7 +411,7 @@ def run_pipeline_vl_chat_test(config,
torch.cuda.empty_cache()
-def internvl_vl_testcase(config, pipe, file, lang='en'):
+def internvl_vl_testcase(pipe, file, resource_path, lang='en'):
if lang == 'cn':
description = DESC_ZH
else:
@@ -422,9 +423,11 @@ def internvl_vl_testcase(config, pipe, file, lang='en'):
dict(type='text',
text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\n{description}'),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)),
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_REDPANDA}')),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_PANDA))
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_PANDA}'))
])
]
response = pipe(messages)
@@ -452,9 +455,11 @@ def internvl_vl_testcase(config, pipe, file, lang='en'):
+ # noqa E251,E501
description),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)),
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_REDPANDA}')),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_PANDA))
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_PANDA}'))
])
]
response = pipe(messages)
@@ -501,8 +506,7 @@ def load_video(video_path, bound=None, num_segments=32):
imgs.append(img)
return imgs
- resource_path = config.get('resource_path')
- video_path = resource_path + '/red-panda.mp4'
+ video_path = f'{resource_path}/red-panda.mp4'
imgs = load_video(video_path, num_segments=8)
question = ''
@@ -546,14 +550,16 @@ def load_video(video_path, bound=None, num_segments=32):
response.text + '\n')
-def llava_vl_testcase(config, pipe, file):
+def llava_vl_testcase(pipe, file, resource_path):
# multi-image multi-round conversation, combined images
messages = [
dict(role='user',
content=[
dict(type='text', text='Describe the two images in detail.'),
- dict(type='image_url', image_url=dict(url=PIC_BEIJING)),
- dict(type='image_url', image_url=dict(url=PIC_CHONGQING))
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}'))
])
]
response = pipe(messages)
@@ -575,16 +581,18 @@ def llava_vl_testcase(config, pipe, file):
response.text + '\n')
-def MiniCPM_vl_testcase(config, pipe, file):
+def MiniCPM_vl_testcase(pipe, file, resource_path):
# Chat with multiple images
messages = [
dict(role='user',
content=[
dict(type='text', text='Describe the two images in detail.'),
dict(type='image_url',
- image_url=dict(max_slice_nums=9, url=PIC_REDPANDA)),
+ image_url=dict(max_slice_nums=9,
+ url=f'{resource_path}/{PIC_REDPANDA}')),
dict(type='image_url',
- image_url=dict(max_slice_nums=9, url=PIC_PANDA))
+ image_url=dict(max_slice_nums=9,
+ url=f'{resource_path}/{PIC_PANDA}'))
])
]
response = pipe(messages)
@@ -602,27 +610,27 @@ def MiniCPM_vl_testcase(config, pipe, file):
response.text + '\n')
# In-context few-shot learning
- EXAMPLE1 = 'https://github.com/user-attachments/assets/405d9147-95f6-4f78-8879-606a0aed6707' # noqa E251,E501
- EXAMPLE2 = 'https://github.com/user-attachments/assets/9f2c6ed9-2aa5-4189-9c4f-0b9753024ba1' # noqa E251,E501
- EXAMPLE3 = 'https://github.com/user-attachments/assets/f335b507-1957-4c22-84ae-ed69ff79df38' # noqa E251,E501
question = 'production date'
messages = [
dict(role='user',
content=[
dict(type='text', text=question),
- dict(type='image_url', image_url=dict(url=EXAMPLE1)),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/data1.jpeg')),
]),
dict(role='assistant', content='2021.08.29'),
dict(role='user',
content=[
dict(type='text', text=question),
- dict(type='image_url', image_url=dict(url=EXAMPLE2)),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/data2.jpeg')),
]),
dict(role='assistant', content='1999.05.15'),
dict(role='user',
content=[
dict(type='text', text=question),
- dict(type='image_url', image_url=dict(url=EXAMPLE3)),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/data3.jpeg')),
])
]
response = pipe(messages)
@@ -651,8 +659,7 @@ def uniform_sample(length, n):
print('num frames:', len(frames))
return frames
- resource_path = config.get('resource_path')
- video_path = resource_path + '/red-panda.mp4'
+    video_path = f'{resource_path}/red-panda.mp4'
frames = encode_video(video_path)
question = 'Describe the video'
@@ -675,14 +682,16 @@ def uniform_sample(length, n):
'\n')
-def Qwen_vl_testcase(config, pipe, file):
+def Qwen_vl_testcase(pipe, file, resource_path):
# multi-image multi-round conversation, combined images
messages = [
dict(role='user',
content=[
dict(type='text', text='Describe the two images in detail.'),
- dict(type='image_url', image_url=dict(url=PIC_BEIJING)),
- dict(type='image_url', image_url=dict(url=PIC_CHONGQING))
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}'))
])
]
response = pipe(messages)
@@ -713,11 +722,11 @@ def Qwen_vl_testcase(config, pipe, file):
dict(type='image_url',
image_url=dict(min_pixels=min_pixels,
max_pixels=max_pixels,
- url=PIC_BEIJING)),
+ url=f'{resource_path}/{PIC_BEIJING}')),
dict(type='image_url',
image_url=dict(min_pixels=min_pixels,
max_pixels=max_pixels,
- url=PIC_CHONGQING))
+ url=f'{resource_path}/{PIC_CHONGQING}'))
])
]
response = pipe(messages)
diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py
index 952de5d9f7..6c33b8bc4b 100644
--- a/benchmark/profile_generation.py
+++ b/benchmark/profile_generation.py
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
+import asyncio
import csv
import os
import time
from dataclasses import dataclass
-from queue import Queue
-from threading import Thread
from typing import List, Union
import numpy as np
@@ -24,8 +23,9 @@
os.environ['TM_LOG_LEVEL'] = 'ERROR'
-def infer(model, session_id: int, input_ids: List,
- gen_config: GenerationConfig, test_round: int, que: Queue):
+async def infer(model, session_id: int, input_ids: List,
+ gen_config: GenerationConfig, test_round: int,
+ que: asyncio.Queue):
if session_id == 1:
pbar = tqdm(total=test_round)
chatbot = model.create_instance()
@@ -47,12 +47,12 @@ def infer(model, session_id: int, input_ids: List,
The time elapsing in this iteration `now-prev` is set to the latency of first token of
the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0`
""" # noqa: E501
- for outputs in chatbot.stream_infer(session_id,
- input_ids,
- gen_config=gen_config,
- sequence_start=True,
- sequence_end=True,
- stream_output=True):
+ async for outputs in chatbot.async_stream_infer(session_id,
+ input_ids,
+ gen_config=gen_config,
+ sequence_start=True,
+ sequence_end=True,
+ stream_output=True):
n_token = outputs.num_token
now = time.perf_counter()
if n_prev_token != n_token:
@@ -61,7 +61,7 @@ def infer(model, session_id: int, input_ids: List,
prev = now
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
- chatbot.end(session_id)
+ await chatbot.async_end(session_id)
if session_id == 1:
pbar.update(1)
@@ -69,39 +69,42 @@ def infer(model, session_id: int, input_ids: List,
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens'
stats.append(token_latency_stats[:output_seqlen])
- que.put((session_id, stats))
+ await que.put((session_id, stats))
def warmup(model, concurrency: int, input_ids: List[int], warmup_round: int,
- gen_config: GenerationConfig):
+ gen_config: GenerationConfig, event_loop: asyncio.BaseEventLoop):
if not warmup_round:
return
print('start to warmup ...')
- def _infer(model, session_id):
+ async def _infer(model, session_id):
chatbot = model.create_instance()
for _ in range(warmup_round):
- for _ in chatbot.stream_infer(session_id,
- input_ids=input_ids,
- sequence_start=True,
- sequence_end=True,
- ignore_eos=True,
- gen_config=gen_config):
+ async for _ in chatbot.async_stream_infer(session_id,
+ input_ids=input_ids,
+ sequence_start=True,
+ sequence_end=True,
+ ignore_eos=True,
+ gen_config=gen_config):
continue
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
- chatbot.end(session_id)
+ await chatbot.async_end(session_id)
_start = time.perf_counter()
- procs = []
+
+    # create warmup coroutines and run them concurrently
+ tasks = []
for i in range(concurrency):
- proc = Thread(target=_infer, args=(model, i + 1), daemon=True)
- procs.append(proc)
- proc.start()
+ task = _infer(model, i + 1)
+ tasks.append(task)
+
+ async def _gather_tasks(tasks):
+ return await asyncio.gather(*tasks)
- for proc in procs:
- proc.join()
+ event_loop.run_until_complete(_gather_tasks(tasks))
_end = time.perf_counter()
print(f'end warmup, elapsed time: {round(_end - _start, 2)}s')
@@ -125,31 +128,34 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
from lmdeploy.pytorch.engine import Engine
tm_model = Engine(model_path, engine_config)
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+
# make up a dummy `input_ids` with the length of `input_seqlen` exactly
assert input_seqlen > 0, 'input_seqlen should > 0'
input_ids = np.random.randint(low=0, high=101, size=input_seqlen).tolist()
- warmup(tm_model, concurrency, input_ids, warmup_round, gen_config)
+ warmup(tm_model, concurrency, input_ids, warmup_round, gen_config,
+ event_loop)
- que = Queue()
- procs = []
+ que = asyncio.Queue()
_start = time.perf_counter()
+ tasks = []
for i in range(concurrency):
- proc = Thread(target=infer,
- args=(tm_model, i + 1, input_ids, gen_config, test_round,
- que))
- procs.append(proc)
- proc.start()
+ task = infer(tm_model, i + 1, input_ids, gen_config, test_round, que)
+ tasks.append(task)
+
+ async def _gather_tasks(tasks):
+ return await asyncio.gather(*tasks)
- for proc in procs:
- proc.join()
+ event_loop.run_until_complete(_gather_tasks(tasks))
_end = time.perf_counter()
elapsed_time = _end - _start
token_latency_stats = []
while not que.empty():
- _, _stats = que.get()
+ _, _stats = que.get_nowait()
token_latency_stats += _stats
# The shape is [concurrency*test_round, output_seqlen]
@@ -426,7 +432,6 @@ def main():
block_size=args.cache_block_seq_len,
session_len=session_len,
tp=args.tp,
- thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
dtype=args.dtype,
diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index 4f06fad4f9..291b1be9b8 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -345,12 +345,14 @@ def main():
requests = sample_requests(args.dataset, args.num_prompts,
engine.tokenizer)
- engine.process_request(requests,
- temperature=args.temperature,
- top_p=args.top_p,
- top_k=args.top_k,
- concurrency=args.concurrency,
- stream_output=True)
+ engine.process_request(
+ requests,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ concurrency=args.concurrency
+ if args.concurrency < args.num_prompts else args.num_prompts,
+ stream_output=True)
if __name__ == '__main__':
diff --git a/docker/Dockerfile b/docker/Dockerfile
index caa58ee637..24b2b055da 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -10,9 +10,6 @@ FROM ${CUDA_VERSION} AS final
ARG PYTHON_VERSION=3.10
-ARG TORCH_VERSION=2.3.0
-ARG TORCHVISION_VERSION=0.18.0
-
RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl openssh-server ssh sudo &&\
curl https://sh.rustup.rs -sSf | sh -s -- -y &&\
add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \
@@ -43,7 +40,6 @@ ENV LD_LIBRARY_PATH=/usr/local/nccl/lib:$LD_LIBRARY_PATH
RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install --upgrade pip setuptools==69.5.1 &&\
- python3 -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\
python3 -m pip install cmake packaging wheel
ENV NCCL_LAUNCH_MODE=GROUP
@@ -54,7 +50,7 @@ COPY . /opt/lmdeploy
WORKDIR /opt/lmdeploy
RUN --mount=type=cache,target=/root/.cache/pip cd /opt/lmdeploy &&\
- python3 -m pip install -r requirements.txt &&\
+ python3 -m pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\
mkdir -p build && cd build &&\
sh ../generate.sh &&\
ninja -j$(nproc) && ninja install &&\
diff --git a/docker/Dockerfile_aarch64_ascend b/docker/Dockerfile_aarch64_ascend
index 1c9591197b..5ed842061c 100644
--- a/docker/Dockerfile_aarch64_ascend
+++ b/docker/Dockerfile_aarch64_ascend
@@ -121,5 +121,4 @@ COPY --from=copy_temp /tmp /opt/lmdeploy
WORKDIR /opt/lmdeploy
RUN --mount=type=cache,target=/root/.cache/pip \
- sed -i '/triton/d' requirements/runtime.txt && \
- pip3 install -v --no-build-isolation -e .
+ LMDEPLOY_TARGET_DEVICE=ascend pip3 install -v --no-build-isolation -e .
diff --git a/docs/en/advance/pytorch_multithread.md b/docs/en/advance/pytorch_multithread.md
new file mode 100644
index 0000000000..446e0fa769
--- /dev/null
+++ b/docs/en/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
+# PyTorchEngine Multithread
+
+We have removed `thread_safe` mode from PytorchEngine since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907). We encourage users to achieve high concurrency by using **service API** or **coroutines** whenever possible, for example:
+
+```python
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+event_loop = asyncio.new_event_loop()
+asyncio.set_event_loop(event_loop)
+
+model_path = 'Llama-3.2-1B-Instruct'
+pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+
+async def _gather_output():
+ tasks = [
+ pipe.async_batch_infer('Hakuna Matata'),
+ pipe.async_batch_infer('giraffes are heartless creatures'),
+ ]
+ return await asyncio.gather(*tasks)
+
+output = asyncio.run(_gather_output())
+print(output[0].text)
+print(output[1].text)
+```
+
+If you do need multithreading, it is easy to wrap it as shown below:
+
+```python
+import threading
+from queue import Queue
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+model_path = 'Llama-3.2-1B-Instruct'
+
+
+async def _batch_infer(inque: Queue, outque: Queue, pipe):
+ while True:
+ if inque.empty():
+ await asyncio.sleep(0)
+ continue
+
+ input = inque.get_nowait()
+ output = await pipe.async_batch_infer(input)
+ outque.put(output)
+
+
+def server(inques, outques):
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+ pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+ for inque, outque in zip(inques, outques):
+ event_loop.create_task(_batch_infer(inque, outque, pipe))
+ event_loop.run_forever()
+
+def client(inque, outque, message):
+ inque.put(message)
+ print(outque.get().text)
+
+
+inques = [Queue(), Queue()]
+outques = [Queue(), Queue()]
+
+t_server = threading.Thread(target=server, args=(inques, outques))
+t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
+t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))
+
+t_server.start()
+t_client0.start()
+t_client1.start()
+
+t_client0.join()
+t_client1.join()
+```
+
+> \[!WARNING\]
+> This is NOT recommended, as multithreading introduces additional overhead, leading to unstable inference performance.
diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md
index 23b86afa61..7da28b5512 100644
--- a/docs/en/get_started/ascend/get_started.md
+++ b/docs/en/get_started/ascend/get_started.md
@@ -18,7 +18,7 @@ cd lmdeploy
### Environment Preparation
-The Docker version is supposed to be no less than `18.03`. And `Ascend Docker Runtime` should be installed by following [the official guide](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/.clusterschedulingig/dlug_installation_012.html).
+The Docker version is supposed to be no less than `18.09`. And `Ascend Docker Runtime` should be installed by following [the official guide](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/.clusterschedulingig/dlug_installation_012.html).
> \[!CAUTION\]
> If error message `libascend_hal.so: cannot open shared object file` shows, that means **Ascend Docker Runtime** is not installed correctly!
@@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```
Please check [supported_models](../../supported_models/supported_models.md) before use this feature.
+
+### int8 KV-cache Quantization
+
+The Ascend backend supports offline int8 KV-cache quantization in eager mode.
+
+Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details.
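+
+A minimal sketch of what usage could look like from the pipeline API is given below. The model path is a placeholder, the offline quantization artifacts are assumed to have been generated as described in the linked dlinfer doc, and `PytorchEngineConfig` is assumed to expose `device_type`, `eager_mode` and `quant_policy` as it does in the CUDA examples; treat this as illustrative rather than a verified recipe.
+
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
+
+# int8 KV cache is supported in eager mode (see the note above);
+# quant_policy=8 selects the int8 KV-cache path
+backend_config = PytorchEngineConfig(tp=1,
+                                     device_type='ascend',
+                                     eager_mode=True,
+                                     quant_policy=8)
+pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
+print(pipe(['Hi, please introduce yourself']))
+```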
diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md
index b3e8bb8abd..8877d510cc 100644
--- a/docs/en/get_started/installation.md
+++ b/docs/en/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
The default prebuilt package is compiled on **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy by:
```shell
-export LMDEPLOY_VERSION=0.6.3
+export LMDEPLOY_VERSION=0.6.5
export PYTHON_VERSION=38
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
diff --git a/docs/en/index.rst b/docs/en/index.rst
index 5d49e01c86..54a36c22c8 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -103,6 +103,7 @@ Documentation
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
+ advance/pytorch_multithread.md
.. toctree::
:maxdepth: 1
diff --git a/docs/en/llm/api_server.md b/docs/en/llm/api_server.md
index 285b0e32ff..274ec2ff25 100644
--- a/docs/en/llm/api_server.md
+++ b/docs/en/llm/api_server.md
@@ -249,6 +249,57 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \
lmdeploy serve gradio api_server_url --server-name ${gradio_ui_ip} --server-port ${gradio_ui_port}
```
+## Launch multiple api servers
+
+The following two steps launch multiple api servers through torchrun; the python script referenced in step 2 is given below.
+
+1. Launch the proxy server through `lmdeploy serve proxy` and note the proxy server url it reports.
+2. Launch the script through `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **Note**: do not use `0.0.0.0:8000` as the proxy url here; pass the real IP address instead, e.g. `11.25.34.55:8000`.
+
+```python
+import os
+import socket
+from typing import List, Literal
+
+import fire
+
+
+def get_host_ip():
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ s.connect(('8.8.8.8', 80))
+ ip = s.getsockname()[0]
+ finally:
+ s.close()
+ return ip
+
+
+def main(model_path: str,
+ tp: int = 1,
+ proxy_url: str = 'http://0.0.0.0:8000',
+ port: int = 23333,
+ backend: Literal['turbomind', 'pytorch'] = 'turbomind'):
+ local_rank = int(os.environ.get('LOCAL_RANK', -1))
+ world_size = int(os.environ.get('WORLD_SIZE', -1))
+ local_ip = get_host_ip()
+ if isinstance(port, List):
+ assert len(port) == world_size
+ port = port[local_rank]
+ else:
+ port += local_rank * 10
+ if (world_size - local_rank) % tp == 0:
+ rank_list = ','.join([str(local_rank + i) for i in range(tp)])
+ command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\
+ f'--server-name {local_ip} --server-port {port} --tp {tp} '\
+ f'--proxy-url {proxy_url} --backend {backend}'
+ print(f'running command: {command}')
+ os.system(command)
+
+
+if __name__ == '__main__':
+ fire.Fire(main)
+```
+
## FAQ
1. When user got `"finish_reason":"length"`, it means the session is too long to be continued. The session length can be
diff --git a/docs/en/multi_modal/llava.md b/docs/en/multi_modal/llava.md
index 8f052227d5..c374b67121 100644
--- a/docs/en/multi_modal/llava.md
+++ b/docs/en/multi_modal/llava.md
@@ -6,11 +6,17 @@ LMDeploy supports the following llava series of models, which are detailed in th
| :----------------------------------: | :--: | :------------------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch |
+| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
+| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
+| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
+| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
The next chapter demonstrates how to deploy an Llava model using LMDeploy, with [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example.
+```{note}
+The PyTorch engine removed support for the original llava models after v0.6.4. Please use the corresponding transformers models instead, which can be found at https://huggingface.co/llava-hf
+```
+
## Installation
Please install LMDeploy by following the [installation guide](../get_started/installation.md).
diff --git a/docs/en/multi_modal/qwen2_vl.md b/docs/en/multi_modal/qwen2_vl.md
index 8b59f84545..fd9f02abaa 100644
--- a/docs/en/multi_modal/qwen2_vl.md
+++ b/docs/en/multi_modal/qwen2_vl.md
@@ -4,7 +4,7 @@ LMDeploy supports the following Qwen-VL series of models, which are detailed in
| Model | Size | Supported Inference Engine |
| :----------: | :----: | :------------------------: |
-| Qwen-VL-Chat | - | TurboMind, Pytorch |
+| Qwen-VL-Chat | - | TurboMind |
| Qwen2-VL | 2B, 7B | PyTorch |
The next chapter demonstrates how to deploy an Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example.
diff --git a/docs/en/quantization/w4a16.md b/docs/en/quantization/w4a16.md
index 0aa1e17a5b..c36c3736c6 100644
--- a/docs/en/quantization/w4a16.md
+++ b/docs/en/quantization/w4a16.md
@@ -128,3 +128,7 @@ We benchmarked the Llama-2-7B-chat and Llama-2-13B-chat models with 4-bit quanti
| ---------------- | ------- | ------- | --------- |
| Llama-2-7B-chat | 112.9 | 159.4 | 206.4 |
| Llama-2-13B-chat | N/A | 90.7 | 115.8 |
+
+## FAQs
+
+1. Out of Memory error during quantization due to insufficient GPU memory: this can be addressed by reducing `--calib-seqlen`, increasing `--calib-samples`, and setting `--batch-size` to 1.
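+
+   For example, a sketch of such an invocation (model path, work dir and the concrete values are placeholders, not recommendations):
+
+   ```python
+   import os
+
+   hf_model = 'internlm/internlm2_5-7b-chat'   # placeholder model
+   work_dir = './internlm2_5-7b-chat-4bit'     # placeholder output dir
+
+   # lower --calib-seqlen, raise --calib-samples and keep --batch-size at 1
+   # to reduce peak GPU memory during calibration
+   os.system(f'lmdeploy lite auto_awq {hf_model} '
+             '--calib-seqlen 1024 --calib-samples 256 --batch-size 1 '
+             f'--work-dir {work_dir}')
+   ```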
diff --git a/docs/en/quantization/w8a8.md b/docs/en/quantization/w8a8.md
index 1b1726bd5f..5cdb48f764 100644
--- a/docs/en/quantization/w8a8.md
+++ b/docs/en/quantization/w8a8.md
@@ -1,55 +1,74 @@
# SmoothQuant
-LMDeploy provides functions for quantization and inference of large language models using 8-bit integers.
+LMDeploy provides functions for quantization and inference of large language models using 8-bit integers (INT8). For GPUs such as the NVIDIA H100, LMDeploy also supports 8-bit floating point (FP8).
-Before starting inference, ensure that lmdeploy and openai/triton are correctly installed. Execute the following commands to install these:
+The following NVIDIA GPUs support INT8 and FP8 inference, respectively; a quick capability check is sketched after the list:
+
+- INT8
+  - Volta(sm70): V100
+ - Turing(sm75): 20 series, T4
+ - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+- FP8
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+
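+For reference, a quick way to check which category your GPU falls into is to query its compute capability, e.g. with PyTorch. The thresholds below are just a restatement of the list above:
+
+```python
+import torch
+
+# (major, minor) compute capability, e.g. (8, 0) for A100, (9, 0) for H100
+major, minor = torch.cuda.get_device_capability(0)
+sm = major * 10 + minor
+print(f'sm{sm}: INT8 {"supported" if sm >= 70 else "not supported"}')
+print(f'sm{sm}: FP8 {"supported" if sm >= 89 else "not supported"}')
+```
+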
+First of all, run the following command to install lmdeploy:
```shell
-pip install lmdeploy
-pip install triton>=2.1.0
+pip install lmdeploy[all]
```
-## 8-bit Weight Model Inference
+## 8-bit Weight Quantization
-For performing 8-bit weight model inference, you can directly download the pre-quantized 8-bit weight models from LMDeploy's [model zoo](https://huggingface.co/lmdeploy). For instance, the 8-bit Internlm-chat-7B model is available for direct download from the model zoo:
+Performing 8-bit weight quantization involves three steps:
-```shell
-git-lfs install
-git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
-```
+1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing.
+2. **Replace Modules**: Locate the DecoderLayers and replace the RMSNorm and nn.Linear modules with QRMSNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.
+3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.
-Alternatively, you can manually convert original 16-bit weights into 8-bit by referring to the content under the ["8bit Weight Quantization"](#8bit-weight-quantization) section. Save them in the internlm-chat-7b-w8 directory, using the command below:
+LMDeploy provides the `lmdeploy lite smooth_quant` command to accomplish all three tasks detailed above. Note that the argument `--quant-dtype` determines whether int8 or fp8 weight quantization is performed. For more information about the CLI usage, run `lmdeploy lite smooth_quant --help`.
-```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
-```
+Here are two examples:
-Afterwards, use the following command to interact with the model via the terminal:
+- int8
-```shell
-lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
-```
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8
+ ```
-## Launching gradio service
+- fp8
-Coming soon...
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8
+ ```
-## Inference Speed
+## Inference
-Coming soon...
+With the following code, you can perform batched offline inference with the quantized model:
-## 8bit Weight Quantization
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
-Performing 8bit weight quantization involves three steps:
+engine_config = PytorchEngineConfig(tp=1)
+pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config)
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```
-1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing.
-2. **Replace Modules**: Locate DecoderLayers and replace the modules RSMNorm and nn.Linear with QRSMNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.
-3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.
+## Service
+
+LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service:
+
+```shell
+lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch
+```
-The script `lmdeploy/lite/apis/smooth_quant.py` accomplishes all three tasks detailed above. For example, you can obtain the model weights of the quantized Internlm-chat-7B model by running the following command:
+The default port of `api_server` is `23333`. After the server is launched, you can communicate with it in the terminal through `api_client`:
```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
+lmdeploy serve api_client http://0.0.0.0:23333
```
-After saving, you can instantiate your quantized model by calling the from_pretrained interface.
+You can view and try out the `api_server` APIs online through the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md).
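+
+Since the RESTful APIs are OpenAI-compatible, the service can also be queried with the official `openai` client. A small sketch, assuming the server started above is listening on the default port:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')
+# the served model name can be fetched from the /v1/models endpoint
+model_name = client.models.list().data[0].id
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{'role': 'user', 'content': 'Hi, please introduce yourself'}])
+print(response.choices[0].message.content)
+```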
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index 469ece487f..cb9805bb0b 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -4,97 +4,107 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
## TurboMind on CUDA Platform
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
-| :-------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | No |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
-| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
-| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
-| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
-| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
-| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes |
-| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
-| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
+| :------------------------------: | :--------------: | :--: | :-------: | :-----: | :-----: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Qwen2-MoE                        | 57B-A14B         | LLM  | Yes       | Yes     | Yes     | Yes   |
+| Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
+| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
+| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
+| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
+| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
+| InternVL2\[2\] | 1 - 2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
+| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
"-" means not verified yet.
```{note}
-The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference.
+* [1] The TurboMind engine doesn't support window attention. Therefore, for models that apply window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral and Qwen1.5, please choose the PyTorch engine for inference.
+* [2] When the head_dim of a model is not 128, as in llama3.2-1B, qwen2-0.5B and internvl2-1B, TurboMind doesn't support 4/8-bit KV cache quantization and inference for it.
```
## PyTorchEngine on CUDA Platform
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
-| :------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
-| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
-| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
-| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
-| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
-| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
-| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No |
-| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
-| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
-| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
-| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes |
-| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
-| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
-| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
-| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
-| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - |
-| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes |
-| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - |
-| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - |
-| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
-| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
-| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
-| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
+| :----------------------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
+| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
+| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
+| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
+| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
+| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
+| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
+| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
+| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
+| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
+| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
+| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
+| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
+| LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No |
+| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
+| InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - |
+| InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - |
+| Mono-InternVL\[1\] | 2B | MLLM | Yes | Yes | Yes | - | - |
+| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
+| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
+| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
+| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
```{note}
-* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
+* [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
+* [2] The PyTorch engine removed support for the original llava models after v0.6.4. Please use their corresponding transformers models instead, which can be found at https://huggingface.co/llava-hf
```
## PyTorchEngine on Huawei Ascend Platform
diff --git a/docs/zh_cn/advance/pytorch_multithread.md b/docs/zh_cn/advance/pytorch_multithread.md
new file mode 100644
index 0000000000..ebd68f503e
--- /dev/null
+++ b/docs/zh_cn/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
+# PyTorchEngine 多线程推理
+
+自 [PR2907](https://github.com/InternLM/lmdeploy/pull/2907) 起,我们废除了 PytorchEngine 的 thread_safe 模式以保证引擎能够更高效地运行。我们鼓励用户尽可能使用**服务接口**或**协程**来实现高并发,比如:
+
+```python
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+event_loop = asyncio.new_event_loop()
+asyncio.set_event_loop(event_loop)
+
+model_path = 'Llama-3.2-1B-Instruct'
+pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+
+async def _gather_output():
+ tasks = [
+ pipe.async_batch_infer('Hakuna Matata'),
+ pipe.async_batch_infer('giraffes are heartless creatures'),
+ ]
+ return await asyncio.gather(*tasks)
+
+output = asyncio.run(_gather_output())
+print(output[0].text)
+print(output[1].text)
+```
+
+如果你确实有多线程推理的需求,那么可以进行简单的封装,来实现类似的效果。
+
+```python
+import threading
+from queue import Queue
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+model_path = 'Llama-3.2-1B-Instruct'
+
+
+async def _batch_infer(inque: Queue, outque: Queue, pipe):
+ while True:
+ if inque.empty():
+ await asyncio.sleep(0)
+ continue
+
+ input = inque.get_nowait()
+ output = await pipe.async_batch_infer(input)
+ outque.put(output)
+
+
+def server(inques, outques):
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+ pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+ for inque, outque in zip(inques, outques):
+ event_loop.create_task(_batch_infer(inque, outque, pipe))
+ event_loop.run_forever()
+
+def client(inque, outque, message):
+ inque.put(message)
+ print(outque.get().text)
+
+
+inques = [Queue(), Queue()]
+outques = [Queue(), Queue()]
+
+t_server = threading.Thread(target=server, args=(inques, outques))
+t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
+t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))
+
+t_server.start()
+t_client0.start()
+t_client1.start()
+
+t_client0.join()
+t_client1.join()
+```
+
+> \[!WARNING\]
+> 我们不鼓励这样实现,多线程会带来额外的开销,使得推理性能不稳定
diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md
index b137c458be..e4790253cd 100644
--- a/docs/zh_cn/get_started/ascend/get_started.md
+++ b/docs/zh_cn/get_started/ascend/get_started.md
@@ -17,7 +17,7 @@ cd lmdeploy
### 环境准备
-Docker 版本应不低于 18.03。并且需按照[官方指南](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/clusterschedulingig/dlug_installation_012.html)安装 Ascend Docker Runtime。
+Docker 版本应不低于 18.09。并且需按照[官方指南](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/clusterschedulingig/dlug_installation_012.html)安装 Ascend Docker Runtime。
> \[!CAUTION\]
> 如果在后续容器内出现`libascend_hal.so: cannot open shared object file`错误,说明Ascend Docker Runtime没有被正确安装。
@@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```
支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。
+
+### int8 KV-cache 量化
+
+昇腾后端现在支持了在 eager 模式下的离线 int8 KV-cache 量化。
+
+详细使用方式请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。
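As a rough sketch of how the Ascend int8 kv-cache path can be enabled from Python (it relies on the `quant_policy` relaxation for the ascend device introduced later in this diff; the model path is only a placeholder):

```python
# Sketch: offline int8 kv-cache quantization on the Ascend backend.
# Assumes an Ascend environment with dlinfer installed; the model path is a placeholder.
from lmdeploy import pipeline, PytorchEngineConfig

engine_config = PytorchEngineConfig(
    device_type='ascend',  # Huawei Ascend backend
    eager_mode=True,       # int8 kv cache is currently limited to eager mode
    quant_policy=8,        # 8 selects the int8 kv cache
)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=engine_config)
print(pipe(['Hakuna Matata']))
```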
diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md
index 12562c51d5..501f8a13e8 100644
--- a/docs/zh_cn/get_started/installation.md
+++ b/docs/zh_cn/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3),你可以使用以下命令安装 lmdeploy:
```shell
-export LMDEPLOY_VERSION=0.6.3
+export LMDEPLOY_VERSION=0.6.5
export PYTHON_VERSION=38
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index 018a00487f..197e800d58 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能:
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
+ advance/pytorch_multithread.md
.. toctree::
:maxdepth: 1
diff --git a/docs/zh_cn/llm/api_server.md b/docs/zh_cn/llm/api_server.md
index d6c0c42aef..8bb91c619e 100644
--- a/docs/zh_cn/llm/api_server.md
+++ b/docs/zh_cn/llm/api_server.md
@@ -258,6 +258,89 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \
}'
```
+## 同时启动多个 api_server
+
+只需两步,即可启动多机多卡服务:先用下面的代码创建一个启动脚本,然后:
+
+1. 启动代理服务 `lmdeploy serve proxy`。
+2. torchrun 启动脚本 `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **注意**: 多机多卡不要用默认 url `0.0.0.0:8000`,我们需要输入真实ip对应的地址,如:`11.25.34.55:8000`。多机情况下,因为不需要子节点间的通信,所以并不需要用户指定 torchrun 的 `--nnodes` 等参数,只要能保证每个节点执行一次单节点的 torchrun 就行。
+
+```python
+import os
+import socket
+from typing import List, Literal
+
+import fire
+
+
+def get_host_ip():
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ s.connect(('8.8.8.8', 80))
+ ip = s.getsockname()[0]
+ finally:
+ s.close()
+ return ip
+
+
+def main(model_path: str,
+ tp: int = 1,
+ proxy_url: str = 'http://0.0.0.0:8000',
+ port: int = 23333,
+ backend: Literal['turbomind', 'pytorch'] = 'turbomind'):
+ local_rank = int(os.environ.get('LOCAL_RANK', -1))
+ world_size = int(os.environ.get('WORLD_SIZE', -1))
+ local_ip = get_host_ip()
+ if isinstance(port, List):
+ assert len(port) == world_size
+ port = port[local_rank]
+ else:
+ port += local_rank * 10
+ if (world_size - local_rank) % tp == 0:
+ rank_list = ','.join([str(local_rank + i) for i in range(tp)])
+ command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\
+ f'--server-name {local_ip} --server-port {port} --tp {tp} '\
+ f'--proxy-url {proxy_url} --backend {backend}'
+ print(f'running command: {command}')
+ os.system(command)
+
+
+if __name__ == '__main__':
+ fire.Fire(main)
+```
+
+### 示例
+
+为了进一步展示如何在集群环境中使用多机多卡服务,下面提供一个在火山云的用例:
+
+```shell
+#!/bin/bash
+# 激活 conda 环境
+source /path/to/your/home/miniconda3/bin/activate /path/to/your/home/miniconda3/envs/your_env
+export HOME=/path/to/your/home
+# 获取主节点IP地址(假设 MLP_WORKER_0_HOST 是主节点的IP)
+MASTER_IP=${MLP_WORKER_0_HOST}
+# 检查是否为主节点
+if [ "${MLP_ROLE_INDEX}" -eq 0 ]; then
+ # 启动 lmdeploy serve proxy 并放入后台
+ echo "Starting lmdeploy serve proxy on master node..."
+ PROXY_PORT=8000
+ lmdeploy serve proxy --server-name ${MASTER_IP} --server-port ${PROXY_PORT} &
+else
+ # 这里我们默认调度平台同时启动了所有机器,否则要sleep一会,等待 proxy 启动成功
+ echo "Not starting lmdeploy serve proxy on worker node ${MLP_ROLE_INDEX}."
+fi
+# 启动 torchrun 并放入后台
+# 再次强调多机环境下并不需要传--nnodes 或者 --master-addr 等参数,相当于每个机器上执行一次单节点的 torchrun 即可。
+torchrun \
+--nproc_per_node=${MLP_WORKER_GPU} \
+/path/to/script.py \
+InternLM/internlm2-chat-1_8b 8 http://${MASTER_IP}:${PROXY_PORT}
+# 打印主机的IP地址
+echo "Host IP addresses:"
+hostname -I
+```
+
## 接入 WebUI
LMDeploy 提供 gradio 和 [OpenAOE](https://github.com/InternLM/OpenAOE) 两种方式,为 api_server 接入 WebUI。
diff --git a/docs/zh_cn/multi_modal/llava.md b/docs/zh_cn/multi_modal/llava.md
index c40f37308a..6538d1b861 100644
--- a/docs/zh_cn/multi_modal/llava.md
+++ b/docs/zh_cn/multi_modal/llava.md
@@ -6,11 +6,17 @@ LMDeploy 支持以下 LLaVA 系列模型,具体如下表所示:
| :----------------------------------: | :--: | :----------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch |
+| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
+| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
+| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
+| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型,并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。
+```{note}
+自 0.6.4 之后,PyTorch 引擎移除了对 llava 原始模型的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到
+```
+
## 安装
请按照[安装指南](../get_started/installation.md)安装 LMDeploy。
diff --git a/docs/zh_cn/multi_modal/qwen2_vl.md b/docs/zh_cn/multi_modal/qwen2_vl.md
index f62d2de74c..7cb7efe93b 100644
--- a/docs/zh_cn/multi_modal/qwen2_vl.md
+++ b/docs/zh_cn/multi_modal/qwen2_vl.md
@@ -4,7 +4,7 @@ LMDeploy 支持 Qwen-VL 系列模型,具体如下:
| Model | Size | Supported Inference Engine |
| :----------: | :----: | :------------------------: |
-| Qwen-VL-Chat | - | TurboMind, Pytorch |
+| Qwen-VL-Chat | - | TurboMind |
| Qwen2-VL | 2B, 7B | PyTorch |
本文将以[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)为例,演示使用 LMDeploy 部署 Qwen2-VL 系列模型的方法
diff --git a/docs/zh_cn/quantization/w4a16.md b/docs/zh_cn/quantization/w4a16.md
index d69a8a23d2..3cea164dd9 100644
--- a/docs/zh_cn/quantization/w4a16.md
+++ b/docs/zh_cn/quantization/w4a16.md
@@ -131,3 +131,8 @@ lmdeploy serve api_client http://0.0.0.0:23333
| ---------------- | ------- | ------- | --------- |
| Llama-2-7B-chat | 112.9 | 159.4 | 206.4 |
| Llama-2-13B-chat | N/A | 90.7 | 115.8 |
+
+## 快速问答
+
+1. 量化时出现 Out of Memory(显存不够):可以通过减小传参 `--calib-seqlen`、增大传参 `--calib-samples`,并将 `--batch-size` 设为 1 来缓解。
+2. 量化时,无法连接 huggingface 并下载数据集:可以尝试使用镜像,`export HF_ENDPOINT=https://hf-mirror.com`。
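For reference, the flags mentioned in item 1 correspond to keyword arguments of the Python quantization API, so the same memory-saving settings can also be passed programmatically. A minimal sketch, assuming the keyword names match the CLI flags and using an illustrative model path and values:

```python
# Sketch only: the CLI flags from item 1 map onto these keyword arguments.
# A smaller calib_seqlen and batch_size=1 lower peak GPU memory during calibration;
# the model path, work_dir and values are illustrative.
from lmdeploy.lite.apis.auto_awq import auto_awq

auto_awq('internlm/internlm2_5-7b-chat',
         work_dir='./internlm2_5-7b-chat-4bit',
         calib_samples=128,
         calib_seqlen=1024,
         batch_size=1)
```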
diff --git a/docs/zh_cn/quantization/w8a8.md b/docs/zh_cn/quantization/w8a8.md
index 302dd538fd..3a63c82f8c 100644
--- a/docs/zh_cn/quantization/w8a8.md
+++ b/docs/zh_cn/quantization/w8a8.md
@@ -1,56 +1,76 @@
# W8A8 LLM 模型部署
-LMDeploy 提供了使用 8 bit 整数对神经网络模型进行量化和推理的功能。
+LMDeploy 提供了使用 8-bit 整数(INT8)和浮点数(FP8)对神经网络模型进行量化和推理的功能。
-在开始推理前,需要确保已经正确安装了 lmdeploy 和 openai/triton。可以通过以下命令进行安装:
+可用于 INT8 和 FP8 推理的 NVIDIA GPU 分别为:
+
+- INT8
+ - V100(sm70): V100
+ - Turing(sm75): 20 series, T4
+ - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+- FP8
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+
+首先,执行如下命令安装lmdeploy:
```shell
-pip install lmdeploy
-pip install triton>=2.1.0
+pip install lmdeploy[all]
```
-## 8bit 权重模型推理
+## 8-bit 权重量化
-如果你需要进行 8 bit 权重模型推理,可以直接从 LMDeploy 的 [model zoo](https://huggingface.co/lmdeploy) 下载已经量化好的 8bit 权重模型。以8bit 的 Internlm-chat-7B 模型为例,可以从 model zoo 直接下载:
+进行 8-bit 权重量化需要经历以下三步:
-```shell
-git-lfs install
-git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
-```
+1. **权重平滑**:首先对语言模型的权重进行平滑处理,以便更好地进行量化。
+2. **模块替换**:使用 `QRMSNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RMSNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。
+3. **保存量化模型**:完成上述必要的替换后,我们即可保存新的量化模型。
-你也可以参考["8bit 权重量化"](#8bit-权重量化)章节的内容手动将原 16bit 权重量化为 8bit,并保存至 `internlm-chat-7b-w8` 目录下,操作命令如下:
+lmdeploy 提供了命令行工具 `lmdeploy lite smooth_quant`,实现了以上三个步骤。其中,命令行参数 `--quant-dtype` 可以用来控制是进行 8-bit 整数还是浮点数类型的量化。更多命令行工具使用方式,请执行 `lmdeploy lite smooth_quant --help` 查看。
-```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
-```
+以下示例演示了进行 int8 或 fp8 的量化命令。
-然后,执行以下命令,即可在终端与模型对话:
+- int8
-```shell
-lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
-```
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8
+ ```
-## 启动 gradio 服务
+- fp8
-Coming soon...
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8
+ ```
-## 推理速度
+## 模型推理
-Coming soon...
+量化后的模型,通过以下几行简单的代码,可以实现离线推理:
-## 8bit 权重量化
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
-进行 8bit 权重量化需要经历以下三步:
+engine_config = PytorchEngineConfig(tp=1)
+pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config)
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```
-1. **权重平滑**:首先对语言模型的权重进行平滑处理,以便更好地进行量化。
-2. **模块替换**:使用 `QRSMNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RSMNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。
-3. **保存量化模型**:完成上述必要的替换后,我们即可保存新的量化模型。
+关于 pipeline 的详细介绍,请参考[这里](../llm/pipeline.md)
-我们在`lmdeploy/lite/api/smooth_quantity.py`脚本中已经实现了以上三个步骤。例如,可以通过以下命令得到量化后的 Internlm-chat-7B 模型的模型权重:
+## 推理服务
+
+LMDeploy `api_server` 支持把模型一键封装为服务,对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例:
```shell
+lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch
+```
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
+服务默认端口是23333。在 server 启动后,你可以在终端通过`api_client`与server进行对话:
+
+```shell
+lmdeploy serve api_client http://0.0.0.0:23333
```
-保存之后,你就可以通过调用from_pretrained接口来实例化你的量化模型。
+还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口,也可直接查阅[文档](../llm/api_server.md),了解各接口的定义和使用方法。
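Beyond `api_client` and the Swagger UI, the server speaks the OpenAI-compatible protocol, so any OpenAI SDK can talk to it. A minimal sketch, assuming the `openai` package is installed and the server started above is reachable at 0.0.0.0:23333:

```python
# Sketch: query the OpenAI-compatible endpoint exposed by `lmdeploy serve api_server`.
# Assumes the int8 model served above; the model name is read from /v1/models.
from openai import OpenAI

client = OpenAI(api_key='none', base_url='http://0.0.0.0:23333/v1')
model_name = client.models.list().data[0].id
resp = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Hi, pls intro yourself'}],
    temperature=0.8,
)
print(resp.choices[0].message.content)
```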
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index d734523282..83b7a9ca6f 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -4,97 +4,107 @@
## TurboMind CUDA 平台
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
-| :-------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | No |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
-| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
-| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
-| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
-| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
-| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes |
-| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
-| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
+| :------------------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
+| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
+| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
+| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
+| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
+| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
+| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
“-” 表示还没有验证。
```{note}
-turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine
+* [1] turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine
+* [2] 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等
```
## PyTorchEngine CUDA 平台
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
-| :------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
-| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
-| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
-| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
-| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
-| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
-| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No |
-| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
-| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
-| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
-| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes |
-| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
-| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
-| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
-| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
-| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - |
-| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes |
-| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - |
-| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - |
-| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
-| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
-| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
-| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
+| :----------------------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
+| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
+| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
+| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
+| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
+| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
+| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
+| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
+| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
+| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
+| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
+| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
+| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
+| LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No |
+| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
+| InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - |
+| InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - |
+| Mono-InternVL\[1\] | 2B | MLLM | Yes\* | Yes | Yes | - | - |
+| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
+| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
+| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
+| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
```{note}
-* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
+* [1] 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16
+* [2] 自 0.6.4 之后,PyTorch 引擎移除了对 llava 模型原始格式的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到
```
## PyTorchEngine 华为昇腾平台
diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py
index ce5cbd98ff..760a82b1c9 100644
--- a/lmdeploy/archs.py
+++ b/lmdeploy/archs.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
-from typing import Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
from transformers import AutoConfig
@@ -128,7 +128,8 @@ def check_vl_llm(config: dict) -> bool:
return True
elif arch == 'MultiModalityCausalLM' and 'language_config' in config:
return True
- elif arch == 'ChatGLMModel' and 'vision_config' in config:
+ elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'
+ ] and 'vision_config' in config:
return True
elif arch in supported_archs:
return True
@@ -193,3 +194,22 @@ def get_model_arch(model_path: str):
raise RuntimeError(
f'Could not find model architecture from config: {_cfg}')
return arch, cfg
+
+
+def search_nested_config(config, key):
+ """Recursively searches for the value associated with the given key in a
+ nested configuration of a model."""
+ if isinstance(config, Dict):
+ for k, v in config.items():
+ if k == key:
+ return v
+ if isinstance(v, (Dict, List)):
+ result = search_nested_config(v, key)
+ if result is not None:
+ return result
+ elif isinstance(config, List):
+ for item in config:
+ result = search_nested_config(item, key)
+ if result is not None:
+ return result
+ return None
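To illustrate the behavior of the new helper, here is a small sketch with a made-up config fragment; the first match found during the depth-first walk is returned, and missing keys yield `None`:

```python
# Sketch: exercising search_nested_config on a hand-written nested config.
from lmdeploy.archs import search_nested_config

config = {
    'architectures': ['ChatGLMForConditionalGeneration'],
    'vision_config': {'hidden_size': 1792, 'layers': [{'head_dim': 112}]},
}
assert search_nested_config(config, 'hidden_size') == 1792  # found in a nested dict
assert search_nested_config(config, 'head_dim') == 112      # found inside a list
assert search_nested_config(config, 'missing_key') is None  # absent keys yield None
```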
diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py
index d76d6a5f34..499bace485 100644
--- a/lmdeploy/cli/lite.py
+++ b/lmdeploy/cli/lite.py
@@ -35,6 +35,7 @@ def add_parser_auto_awq():
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
+ ArgumentHelper.dtype(parser)
parser.add_argument(
'--device',
type=str,
@@ -71,6 +72,7 @@ def add_parser_auto_gptq():
ArgumentHelper.calib_samples(parser)
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
+ ArgumentHelper.dtype(parser)
parser.add_argument('--w-bits',
type=int,
default=4,
@@ -99,6 +101,7 @@ def add_parser_calibrate():
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
+ ArgumentHelper.dtype(parser)
@staticmethod
def add_parser_smooth_quant():
@@ -122,6 +125,8 @@ def add_parser_smooth_quant():
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
+ ArgumentHelper.dtype(parser)
+ ArgumentHelper.quant_dtype(parser)
@staticmethod
def auto_awq(args):
diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index d4a0e54b1b..939d7a2f7b 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -239,6 +239,7 @@ def add_parser_proxy():
help='the strategy to dispatch requests to nodes')
ArgumentHelper.api_keys(parser)
ArgumentHelper.ssl(parser)
+ ArgumentHelper.log_level(parser)
@staticmethod
def gradio(args):
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index 6db44930f4..33d5d339cf 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -122,6 +122,16 @@ def dtype(parser, default: str = 'auto'):
'for BF16 models. This option will be ignored if '
'the model is a quantized model')
+ @staticmethod
+ def quant_dtype(parser, default: str = 'int8'):
+ return parser.add_argument(
+ '--quant-dtype',
+ type=str,
+ default=default,
+ choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'],
+            help='data type for the quantized model weights and activations. '
+ 'Note "fp8" is the short version of "float8_e4m3fn"')
+
@staticmethod
def model_format(parser, default: str = None):
return parser.add_argument(
@@ -363,7 +373,7 @@ def calib_batchsize(parser):
@staticmethod
def calib_search_scale(parser):
- """Add argument batch_size to parser."""
+ """Add argument search_scale to parser."""
return parser.add_argument(
'--search-scale',
diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py
index c41b28fd6e..2c84612839 100644
--- a/lmdeploy/lite/apis/auto_awq.py
+++ b/lmdeploy/lite/apis/auto_awq.py
@@ -2,6 +2,7 @@
import os
import os.path as osp
import shutil
+from typing import Literal
import torch
from torch import nn
@@ -12,9 +13,7 @@
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.check_env import try_import_deeplink
-from .calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate
-
-NORM_TYPE_MAP = NORM_TYPE_MAP # legacy
+from .calibrate import LAYER_TYPE_MAP, calibrate
def save_vl_model(vl_model, model_path, dst_path):
@@ -56,6 +55,7 @@ def auto_awq(model: str,
search_scale: bool = False,
device: str = 'cuda',
revision: str = None,
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
download_dir: str = None):
"""Perform weight quantization using AWQ algorithm.
@@ -77,6 +77,7 @@ def auto_awq(model: str,
revision (str): The specific model version to use. It can be a
branch name, a tag name, or a commit id. If unspecified,
will use the default version.
+ dtype (str): Data type for loading model weights and calib infer.
download_dir (str): Directory to download and load the weights,
default to the default cache directory of huggingface.
"""
@@ -96,6 +97,7 @@ def auto_awq(model: str,
w_bits=w_bits,
w_group_size=w_group_size,
search_scale=search_scale,
+ dtype=dtype,
batch_size=batch_size)
layer_type = LAYER_TYPE_MAP[type(model).__name__]
diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py
index 71f7a5900c..007f831a70 100644
--- a/lmdeploy/lite/apis/calibrate.py
+++ b/lmdeploy/lite/apis/calibrate.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
-from typing import Union
+from typing import Literal, Union
import torch
from torch import nn
@@ -11,6 +11,7 @@
from lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2
from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
load_hf_from_pretrained)
+from lmdeploy.vl.model.builder import load_vl_model
LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
@@ -204,6 +205,7 @@ def calibrate(model: str,
w_bits: int = 4,
w_group_size: int = 128,
search_scale: bool = False,
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
batch_size: int = 1) -> None:
"""The main function for loading the model and performing calibration on a
given dataset.
@@ -224,6 +226,7 @@ def calibrate(model: str,
w_group_size (int): Group size for weight quantization statistics.
search_scale (bool): Whether search scale ratio. Default to False,
which means only smooth quant with 0.5 ratio will be applied.
+ dtype (str): Data type for loading model weights and calib infer.
batch_size (int): The batch size for running the calib samples.
Low GPU mem requires small batch_size. Large batch_size
reduces the calibration time while costs more VRAM.
@@ -239,20 +242,35 @@ def calibrate(model: str,
model_type, _ = get_task(model)
make_compatible_internvl_config(model)
- if model_type == 'llm':
- # Load tokenizer and configuration
- tokenizer = AutoTokenizer.from_pretrained(model,
- trust_remote_code=True)
+ # Load tokenizer and configuration
+ tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+
+ if model_type == 'llm':
model = load_hf_from_pretrained(model,
- torch_dtype=torch.float16,
+ dtype=dtype,
trust_remote_code=True)
vl_model = None
elif model_type == 'vlm':
- from lmdeploy.vl.model.builder import vl_model_with_tokenizer
- vl_model, model, tokenizer = vl_model_with_tokenizer(model_path=model)
+ vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
+ model = vl_model
+ if hasattr(vl_model, 'language_model'): # deepseek-vl, ...
+ model = vl_model.language_model
+ if hasattr(vl_model, 'llm'): # MiniCPMV, ...
+ model = vl_model.llm
+ model.config.use_cache = False
+ if dtype == 'float16':
+ model.half()
+ elif dtype == 'bfloat16':
+ assert torch.cuda.is_bf16_supported(
+        ), 'your device does not support bfloat16. Please set --dtype float16'  # noqa
+ model.to(torch.bfloat16)
+ elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16:
+        print('Warning: we cast the model to float16 to prevent OOM. You'
+              ' may enforce bfloat16 by `--dtype bfloat16`')
+ model.half()
+ model.eval()
- model.config.use_cache = False
model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
raise RuntimeError(
diff --git a/lmdeploy/lite/apis/gptq.py b/lmdeploy/lite/apis/gptq.py
index 12b88a52cd..eb4418a533 100644
--- a/lmdeploy/lite/apis/gptq.py
+++ b/lmdeploy/lite/apis/gptq.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
+from typing import Literal
import torch
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
from lmdeploy.lite.utils.calib_dataloader import get_calib_loaders
@@ -15,6 +16,7 @@ def auto_gptq(model: str,
calib_samples: int = 128,
calib_seqlen: int = 2048,
batch_size: int = 1,
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
revision: str = None):
"""Perform weight quantization using AWQ algorithm.
@@ -29,9 +31,7 @@ def auto_gptq(model: str,
calib_seqlen (int): The sequence length for calibration.
w_bits (int): Bit number for weight quantization.
w_group_size (int): Group size for weight quantization statistics.
- search_scale (bool): Whether search scale ratio. Default to False,
- which means only smooth quant with 0.5 ratio will be applied.
- device (str): Device type of running.
+ dtype (str): Data type for loading model weights and calib infer.
revision (str): The specific model version to use. It can be a
branch name, a tag name, or a commit id. If unspecified,
will use the default version.
@@ -83,9 +83,18 @@ def auto_gptq(model: str,
# load un-quantized model, by default,
# the model will always be loaded into CPU memory
+ hf_config = AutoConfig.from_pretrained(pretrained_model_dir,
+ revision=revision,
+ trust_remote_code=True)
+ torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)
+ if dtype == 'float16':
+ torch_dtype = torch.float16
+ elif dtype == 'bfloat16':
+ torch_dtype = torch.bfloat16
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir,
quantize_config,
revision=revision,
+ torch_dtype=torch_dtype,
trust_remote_code=True)
# quantize model, the examples should be list of dict whose keys
diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index c8df67355e..8d67535bcc 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Literal
+
import fire
import torch
from torch import nn
@@ -6,7 +9,8 @@
from lmdeploy.lite.apis.calibrate import (LAYER_TYPE_MAP, NORM_TYPE_MAP,
calibrate)
from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
- awq_layers, smooth_layers)
+ awq_layers, skipped_module,
+ smooth_layers)
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.models import QLinear, QRMSNorm
@@ -19,8 +23,20 @@ def smooth_quant(model: str,
search_scale: bool = False,
batch_size: int = 1,
w_bits: int = 8,
- device: str = 'cuda'):
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
+ device: str = 'cuda',
+ quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn',
+ 'float8_e5m2'] = 'int8'):
+ if quant_dtype == 'fp8':
+ quant_dtype = 'float8_e4m3fn'
+
+ quant_dtype = getattr(torch, quant_dtype, torch.int8)
+ if quant_dtype.is_floating_point:
+ q_dtype_info = torch.finfo(quant_dtype)
+ else:
+ q_dtype_info = torch.iinfo(quant_dtype)
+ assert q_dtype_info.bits == w_bits
model_path = model
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
@@ -31,6 +47,7 @@ def smooth_quant(model: str,
w_bits=w_bits,
w_group_size=-1,
search_scale=search_scale,
+ dtype=dtype,
batch_size=batch_size)
# calibrate function exports the calibration statistics
@@ -76,16 +93,20 @@ def smooth_quant(model: str,
rmsnorms = collect_target_modules(model, norm_type)
for name, linear in fcs.items():
+ if skipped_module(name):
+ continue
linear.to(device)
- q_linear = QLinear.from_float(linear)
+ q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype)
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_linear)
linear.to('cpu')
for name, norm in rmsnorms.items():
+ if skipped_module(name):
+ continue
norm.to(device)
- q_norm = QRMSNorm.from_float(norm)
+ q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype)
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_norm)
@@ -95,8 +116,10 @@ def smooth_quant(model: str,
from .auto_awq import save_vl_model
save_vl_model(vl_model, model_path, work_dir)
else:
+ quant_dtype_s = str(quant_dtype).split('.')[1]
model.config.update(
- dict(quantization_config=dict(quant_method='smooth_quant')))
+ dict(quantization_config=dict(quant_method='smooth_quant',
+ quant_dtype=f'{quant_dtype_s}')))
model.save_pretrained(work_dir,
max_shard_size='2GB',
safe_serialization=False)
diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py
index cf03a75216..3e24a13cc3 100644
--- a/lmdeploy/lite/quantization/awq.py
+++ b/lmdeploy/lite/quantization/awq.py
@@ -43,8 +43,10 @@
'MixtralDecoderLayer': {
'input_layernorm':
['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
- 'post_attention_layernorm':
- ['block_sparse_moe.experts.{i}.w1', 'block_sparse_moe.experts.{i}.w3']
+ 'post_attention_layernorm': [
+ 'block_sparse_moe.gate', 'block_sparse_moe.experts.{i}.w1',
+ 'block_sparse_moe.experts.{i}.w3'
+ ]
},
'Qwen2VLDecoderLayer': {
'input_layernorm':
@@ -120,7 +122,12 @@ def get_weight_scale(weight, q_group_size=-1):
org_shape = weight.shape
if q_group_size > 0:
weight = weight.view(-1, q_group_size)
- scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+ abs_weight = weight.abs()
+ abs_weight_amax = abs_weight.amax(dim=1, keepdim=True)
+ if abs_weight_amax.min().item() == 0:
+ print('weight.amax.min is zero, clamping weight.amax to 1e-4')
+ abs_weight_amax = abs_weight_amax.clamp(min=1e-4)
+ scale = abs_weight / abs_weight_amax
scale = scale.view(org_shape)
scale = scale.mean(0)
return scale
@@ -153,8 +160,13 @@ def smooth_ln_fcs(ln: torch.nn.Module,
concat_w = torch.cat([fc.weight for fc in fcs], dim=0)
w_scales = get_weight_scale(concat_w, group_size)
+ w_scales_pow = w_scales.pow(1 - alpha)
+ if w_scales_pow.min().item() == 0:
+ print('w_scales.pow(1 - alpha).min is zero, '
+ 'clamping w_scales.pow(1 - alpha) to 1e-4')
+ w_scales_pow = w_scales_pow.clamp(min=1e-4)
scales = (act_scales.pow(alpha) /
- w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
+ w_scales_pow).clamp(min=1e-4).to(device).to(dtype)
scales = scales / (scales[nonzero_positions].max() *
scales[nonzero_positions].min()).sqrt()
@@ -204,8 +216,13 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
concat_w = torch.cat([fc.weight for fc in fcs], dim=0)
w_scales = get_weight_scale(concat_w, group_size)
+ w_scales_pow = w_scales.pow(1 - alpha)
+ if w_scales_pow.min().item() == 0:
+ print('w_scales.pow(1 - alpha).min is zero, '
+ 'clamping w_scales.pow(1 - alpha) to 1e-4')
+ w_scales_pow = w_scales_pow.clamp(min=1e-4)
scales = (act_scales.pow(alpha) /
- w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
+ w_scales_pow).clamp(min=1e-4).to(device).to(dtype)
scales = scales / (scales.max() * scales.min()).sqrt()
# (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale
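The clamping guards added above (and the analogous one in `calibration.py` below) protect against weight groups whose absolute maximum is exactly zero, which would otherwise turn the smoothing scales into NaNs. A tiny standalone sketch of the failure mode:

```python
# Toy illustration: an all-zero weight group gives amax == 0, and 0/0 produces NaNs
# that would propagate into the AWQ smoothing scales without the clamp.
import torch

weight = torch.zeros(4, 8)
amax = weight.abs().amax(dim=1, keepdim=True)               # all zeros
print((weight.abs() / amax).isnan().all())                  # tensor(True)
print((weight.abs() / amax.clamp(min=1e-4)).isnan().any())  # tensor(False)
```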
diff --git a/lmdeploy/lite/quantization/calibration.py b/lmdeploy/lite/quantization/calibration.py
index e590f1a4eb..1df8f2c740 100644
--- a/lmdeploy/lite/quantization/calibration.py
+++ b/lmdeploy/lite/quantization/calibration.py
@@ -42,6 +42,9 @@ def __init__(self,
tokenizer (PreTrainedTokenizer): Tokenizer of the given model.
layer_type (Union[str, type]): Type of the layers to be observed.
norm_type (Union[str, type]): Norm type used in the model.
+ batch_size (int): The batch size for running the calib samples.
+ Low GPU mem requires small batch_size. Large batch_size
+ reduces the calibration time while costs more VRAM.
device (str, optional): Device where the model should run.
Defaults to 'cuda'.
"""
@@ -290,9 +293,14 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}):
org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
for ratio in range(0, n_grid):
- ratio = ratio * 1 / n_grid
- scales = (x_max.pow(ratio) /
- w_mean.pow(1 - ratio)).clamp(min=1e-4).view(-1)
+ ratio = ratio / n_grid
+ w_mean_pow = w_mean.pow(1 - ratio)
+ if w_mean_pow.min().item() == 0:
+ print('w_mean.pow(1 - ratio).min is zero, '
+ 'clamping w_mean.pow(1 - ratio) to 1e-4')
+ w_mean_pow = w_mean_pow.clamp(min=1e-4)
+ scales = (x_max.pow(ratio) / w_mean_pow).clamp(min=1e-4).view(-1)
+
scales = scales / (scales.max() * scales.min()).sqrt()
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py
index bfd306a743..ac4519371a 100644
--- a/lmdeploy/lite/utils/load.py
+++ b/lmdeploy/lite/utils/load.py
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Literal
+
import torch
from transformers import AutoConfig, AutoModelForCausalLM
@@ -7,29 +9,42 @@
def load_hf_from_pretrained(pretrained_model_name_or_path,
- dtype=torch.float16,
- **kwargs):
+ dtype: Literal['float16', 'bfloat16',
+ 'auto'], **kwargs):
- if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+ if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported():
raise RuntimeError('Your device does not supports bf16(bfloat16), '
'please change to fp16(float16)')
kwargs.pop('config', None)
hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
- torch_dtype=dtype,
trust_remote_code=True)
# HACK hard code for qwen, other configs do not have the `fp16` attribute.
- if dtype == torch.float16:
- hf_config.fp16 = True
- elif dtype == torch.bfloat16:
- hf_config.bf16 = True
+ if hasattr(hf_config, 'fp16') or hasattr(hf_config, 'bf16'):
+ if dtype == 'bfloat16':
+ hf_config.bf16 = True
+ else:
+ hf_config.fp16 = True
+
+ torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)
+ if dtype == 'bfloat16':
+ torch_dtype = torch.bfloat16
+ elif dtype == 'float16':
+ torch_dtype = torch.float16
+ elif dtype == 'auto' and torch_dtype == torch.bfloat16:
+        print('Warning: we cast the model to float16 to prevent OOM. '
+              'You may enforce bfloat16 by `--dtype bfloat16`')
+ torch_dtype = torch.float16
with LoadNoInit():
# Load model
model = AutoModelForCausalLM.from_pretrained(
- pretrained_model_name_or_path, config=hf_config, **kwargs)
+ pretrained_model_name_or_path,
+ config=hf_config,
+ torch_dtype=torch_dtype,
+ **kwargs)
model.config.use_cache = False
return model
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 90823598ea..2336d10752 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -293,8 +293,11 @@ def __post_init__(self):
assert self.device_type in [
'cuda', 'ascend', 'maca'
], (f'invalid device_type: {self.device_type}')
- if self.quant_policy > 0 and self.device_type != 'cuda':
- assert False, 'kv cache quantization only works for CUDA.'
+ if self.quant_policy > 0 and self.device_type not in [
+ 'cuda', 'ascend'
+ ]:
+ assert False, \
+ 'kv cache quantization only works for CUDA and ASCEND.'
class ResponseType(enum.Enum):
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index a4355ea131..f7b80ed102 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -46,6 +46,8 @@ class ChatTemplateConfig:
eoh (str | None): end of the user prompt
assistant (str | None): begin of the assistant prompt
eoa (str | None): end of the assistant prompt
+ tool (str | None): begin of the tool prompt
+ eotool (str | None): end of the tool prompt
capability: ('completion' | 'infilling' | 'chat' | 'python') = None
""" # noqa: E501
@@ -57,6 +59,8 @@ class ChatTemplateConfig:
eoh: Optional[str] = None
assistant: Optional[str] = None
eoa: Optional[str] = None
+ tool: Optional[str] = None
+ eotool: Optional[str] = None
separator: Optional[str] = None
capability: Optional[Literal['completion', 'infilling', 'chat',
'python']] = None
@@ -173,6 +177,8 @@ def __init__(self,
assistant='',
eoa='',
separator='',
+ tool='',
+ eotool='',
**kwargs):
super().__init__(**kwargs)
self.system = system
@@ -183,6 +189,8 @@ def __init__(self,
self.separator = separator
self.eosys = eosys
self.assistant = assistant
+ self.tool = tool
+ self.eotool = eotool
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
@@ -223,10 +231,12 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
return self.get_prompt(messages, sequence_start)
box_map = dict(user=self.user,
assistant=self.assistant,
- system=self.system)
+ system=self.system,
+ tool=self.tool)
eox_map = dict(user=self.eoh,
assistant=self.eoa + self.separator,
- system=self.eosys)
+ system=self.eosys,
+ tool=self.eotool)
ret = ''
if self.meta_instruction is not None and sequence_start:
if len(messages) and messages[0]['role'] != 'system':
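The effect of the new `tool`/`eotool` fields is that a message with role `tool` gets its own begin/end wrappers in `messages2prompt`, just like user, assistant and system messages. Below is a standalone sketch of that mapping with invented markers, not lmdeploy's actual templates:

```python
# Standalone sketch of the role wrapping added above; the markers are made up.
box_map = {'user': '<|user|>\n', 'assistant': '<|assistant|>\n',
           'system': '<|system|>\n', 'tool': '<|tool|>\n'}
eox_map = {'user': '<|end|>\n', 'assistant': '<|end|>\n',
           'system': '<|end|>\n', 'tool': '<|end|>\n'}

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'What is the weather in Paris?'},
    {'role': 'tool', 'content': '{"temperature_c": 21}'},
]
prompt = ''.join(box_map[m['role']] + m['content'] + eox_map[m['role']]
                 for m in messages)
print(prompt)
```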
@@ -819,7 +829,7 @@ class Llama3_1(Llama3):
def __init__(
self,
- tools="""# Tool Instructions
+ tool="""# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real time information use relevant functions if available else fallback to brave_search
@@ -828,7 +838,7 @@ def __init__(
You have access to the following functions:
""", # noqa
- eotools="""
+ eotool="""
If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
@@ -847,7 +857,7 @@ def __init__(
- Only call one function at a time
- Put the entire function call reply on one line"
- Always add your sources when using search results to answer the user query\n\n""", # noqa
- knowledge='Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n',
+ knowledge='Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n',
meta_instruction='You are a helpful assistant.',
ipython='<|start_header_id|>ipython<|end_header_id|>\n\n',
eoi='<|eot_id|>',
@@ -858,8 +868,8 @@ def __init__(
**kwargs)
self.ipython = ipython
self.eoi = eoi
- self.tools = tools
- self.eotools = eotools
+ self.tool = tool
+ self.eotool = eotool
self.knowledge = knowledge
def messages2prompt(self,
@@ -899,7 +909,7 @@ def messages2prompt(self,
if tools is None:
ret += f'{self.system}{self.knowledge}{self.meta_instruction}{self.eosys}'
else:
- ret += f'{self.system}{self.knowledge}{self.tools}{tool_prompt}{self.eotools}{self.meta_instruction}{self.eosys}'
+ ret += f'{self.system}{self.knowledge}{self.tool}{tool_prompt}{self.eotool}{self.meta_instruction}{self.eosys}'
for message in messages:
role = message['role']
content = get_text(message['content'])
@@ -907,7 +917,7 @@ def messages2prompt(self,
or '' in content):
ret += f'{box_map[role]}{content}<|eom_id|>'
elif role == 'system' and tools is not None:
- ret += f'{box_map[role]}{self.tools}{tool_prompt}{self.eotools}{content}{eox_map[role]}'
+ ret += f'{box_map[role]}{self.tool}{tool_prompt}{self.eotool}{content}{eox_map[role]}'
else:
ret += f'{box_map[role]}{content}{eox_map[role]}'
if sequence_start and not isinstance(messages, str):
@@ -1921,5 +1931,5 @@ def best_match_model(query: str) -> Optional[str]:
for name, model in MODELS.module_dict.items():
if model.match(query):
return model.match(query)
- logger.warn(f'Did not find a chat template matching {query}.')
+ logger.warning(f'Did not find a chat template matching {query}.')
return 'base'
diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py
index 92a0befbf4..f0e60d86ac 100644
--- a/lmdeploy/pytorch/backends/attention.py
+++ b/lmdeploy/pytorch/backends/attention.py
@@ -34,6 +34,7 @@ def __init__(
alibi: bool = None,
sliding_window: int = None,
logit_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> None:
if scale is None:
@@ -53,6 +54,7 @@ def __init__(
self.alibi = alibi
self.sliding_window = sliding_window
self.logit_softcapping = logit_softcapping
+ self.causal = causal
@abstractmethod
def forward(
@@ -82,6 +84,7 @@ def build(
alibi: bool = False,
sliding_window: int = None,
logical_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> AttentionImpl[T]:
"""build."""
diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py
index ef538f7a3d..263b419f1a 100644
--- a/lmdeploy/pytorch/backends/base.py
+++ b/lmdeploy/pytorch/backends/base.py
@@ -12,7 +12,8 @@
class OpType(Enum):
"""Layer type enumerate."""
- Attention = auto()
+ PagedAttention = auto()
+ FlashAttention = auto()
Linear = auto()
RotaryEmbedding = auto()
ApplyRotaryEmb = auto()
@@ -27,6 +28,9 @@ class OpType(Enum):
LinearW4A16 = auto()
SoftmaxTopK = auto()
FusedMoE = auto()
+ FusedMoEW8A8 = auto()
+ LinearBlockedF8 = auto()
+ FusedMoEBlockedF8 = auto()
class OpsBackend(ABC):
diff --git a/lmdeploy/pytorch/backends/blockedf8_modules.py b/lmdeploy/pytorch/backends/blockedf8_modules.py
new file mode 100644
index 0000000000..d79b41330c
--- /dev/null
+++ b/lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+
+class LinearBlockedF8Impl(ABC):
+ """linear BlockedF8 implementation api."""
+
+ def update_weights(self,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ bias: Optional[torch.Tensor] = None):
+ """update weights."""
+ return weight, scale, bias
+
+ @abstractmethod
+ def forward(self,
+ x,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ all_reduce: bool = False):
+ """forward."""
+ raise NotImplementedError
+
+
+class LinearBlockedF8Builder(ABC):
+ """linear BlockedF8 implementation builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype: torch.dtype = None):
+ """build."""
+ raise NotImplementedError
diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py
index 1672803ff4..31546ae0e1 100644
--- a/lmdeploy/pytorch/backends/cuda/attention.py
+++ b/lmdeploy/pytorch/backends/cuda/attention.py
@@ -42,6 +42,7 @@ def __init__(
alibi: bool = False,
sliding_window: int = None,
logit_softcapping: float = None,
+ causal: bool = True,
**kwargs,
):
super().__init__(
@@ -53,8 +54,10 @@ def __init__(
alibi=alibi,
sliding_window=sliding_window,
logit_softcapping=logit_softcapping,
+ causal=causal,
**kwargs,
)
+ assert not (alibi and not causal)
from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd,
fill_kv_cache,
@@ -177,6 +180,7 @@ def forward(
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
+ causal=self.causal,
)
else:
self.alibi_paged_attention_fwd(
@@ -212,6 +216,7 @@ def build(
alibi: bool = False,
sliding_window: int = None,
logical_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> TritonAttentionImpl:
"""build."""
@@ -223,4 +228,5 @@ def build(
alibi=alibi,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
+ causal=causal,
**kwargs)
diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py
index 8159bbf554..18b0150493 100644
--- a/lmdeploy/pytorch/backends/cuda/awq_modules.py
+++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py
@@ -18,23 +18,14 @@ def wq_gemm_forward(
out_features=0,
):
"""wq gemm forward."""
- from awq.modules.linear.gemm import awq_ext
-
from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_linear
out_shape = x.shape[:-1] + (out_features, )
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
- FP16_MATMUL_HEURISTIC_CONDITION = x.size(0) * x.size(1) >= 64
-
x = x.flatten(0, -2)
- if FP16_MATMUL_HEURISTIC_CONDITION:
- out = awq_linear(x, qweight, scales, qzeros)
- else:
- if not x.is_contiguous():
- x = x.contiguous()
- out = awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8)
+ out = awq_linear(x, qweight, scales, qzeros)
out = out + bias if bias is not None else out
out = out.reshape(out_shape)
diff --git a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
new file mode 100644
index 0000000000..8299ac2dfd
--- /dev/null
+++ b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import (blocked_gemm_fp8,
+ quant_fp8)
+
+from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl
+
+
+class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
+ """triton linear blocked f8 implementation."""
+
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ block_size: int,
+ out_dtype: torch.dtype = torch.float16):
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out_dtype = out_dtype
+ self.block_size = block_size
+
+ def forward(self,
+ x,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ all_reduce: bool = False):
+ """forward."""
+ x_shape = x.shape
+ x = x.flatten(0, -2)
+ input_quant, input_scale = quant_fp8(x,
+ self.block_size,
+ dtype=weight.dtype)
+
+ out = blocked_gemm_fp8(input_quant,
+ input_scale,
+ weight.t(),
+ scale.t(),
+ out_dtype=x.dtype)
+ if bias is not None:
+ out += bias
+
+ if all_reduce:
+ dist.all_reduce(out)
+
+ out = out.unflatten(0, x_shape[:-1])
+ return out
+
+
+class TritonLinearBlockedF8Builder(LinearBlockedF8Builder):
+ """triton linear blocked f8 implementation builder."""
+
+ @staticmethod
+ def build(in_features: int,
+ out_features: int,
+ block_size: int = 128,
+ bias: bool = True,
+ dtype: torch.dtype = None):
+ """build."""
+ return TritonLinearBlockedF8Impl(in_features, out_features, block_size,
+ dtype)
diff --git a/lmdeploy/pytorch/backends/cuda/flash_attention.py b/lmdeploy/pytorch/backends/cuda/flash_attention.py
new file mode 100644
index 0000000000..5d3925b744
--- /dev/null
+++ b/lmdeploy/pytorch/backends/cuda/flash_attention.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import Tensor
+
+from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl
+
+
+class TritonFlashAttentionImpl(FlashAttentionImpl):
+ """triton flash attention implementation."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ ):
+ if scale is None:
+ scale = 1.0 / (head_dim**0.5)
+
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+
+ if v_head_dim is None:
+ v_head_dim = head_dim
+
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.scale = scale
+ self.num_kv_heads = num_kv_heads
+ self.v_head_dim = v_head_dim
+ self.causal = causal
+ self.sliding_window = sliding_window
+ self.logical_softcapping = logical_softcapping
+
+ from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd
+ self.flash_attention_fwd = flash_attention_fwd
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None):
+ """forward."""
+
+ q_shape = query.shape
+ o_shape = q_shape[:-1] + (self.v_head_dim, )
+ out = query.new_empty(o_shape)
+ self.flash_attention_fwd(
+ query,
+ key,
+ value,
+ out,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_seqlen=max_q_seqlen,
+ window_size=self.sliding_window,
+ sm_scale=self.scale,
+ logit_softcapping=self.logical_softcapping,
+ causal=self.causal,
+ kv_layout='shd',
+ )
+
+ return out
+
+
+class TritonFlashAttentionBuilder(FlashAttentionBuilder):
+ """triton attention builder."""
+
+ @staticmethod
+ def build(
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ **kwargs,
+ ) -> FlashAttentionImpl:
+ """build."""
+ return TritonFlashAttentionImpl(
+ num_heads=num_heads,
+ head_dim=head_dim,
+ scale=scale,
+ num_kv_heads=num_kv_heads,
+ v_head_dim=v_head_dim,
+ causal=causal,
+ sliding_window=sliding_window,
+ logical_softcapping=logical_softcapping,
+ )
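
The new Triton flash-attention backend is resolved through OpType.FlashAttention in op_backend.py later in this diff. A hedged sketch of driving the builder directly, assuming a packed (total_tokens, heads, head_dim) layout for the inputs:

    # Sketch only: exercises TritonFlashAttentionBuilder.build/forward as defined above.
    import torch
    from lmdeploy.pytorch.backends.cuda.flash_attention import \
        TritonFlashAttentionBuilder

    attn = TritonFlashAttentionBuilder.build(num_heads=32, head_dim=128,
                                             num_kv_heads=8, causal=True)

    q = torch.randn(24, 32, 128, dtype=torch.float16, device='cuda')
    k = torch.randn(24, 8, 128, dtype=torch.float16, device='cuda')
    v = torch.randn(24, 8, 128, dtype=torch.float16, device='cuda')
    seqlens = torch.tensor([8, 16], device='cuda')
    start_loc = torch.tensor([0, 8], device='cuda')

    out = attn.forward(q, k, v,
                       q_start_loc=start_loc, q_seqlens=seqlens,
                       kv_start_loc=start_loc, kv_seqlens=seqlens,
                       max_q_seqlen=16)  # out: (24, 32, 128)
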
diff --git a/lmdeploy/pytorch/backends/cuda/lora.py b/lmdeploy/pytorch/backends/cuda/lora.py
index 798d985715..b65a01df14 100644
--- a/lmdeploy/pytorch/backends/cuda/lora.py
+++ b/lmdeploy/pytorch/backends/cuda/lora.py
@@ -50,6 +50,15 @@ def forward(self,
"""forward."""
lora_input = self._make_packed_lora_input(x, ctx_mgr)
+ base_slice = adapter_info.base_slice
+ sliced_base = base_output[..., base_slice]
+
+ if base_output.is_contiguous():
+ kernel_output = sliced_base.flatten(0, -2)
+ cum = True
+ else:
+ kernel_output = None
+ cum = False
lora_out = fused_lora(
lora_input.x,
lora_A,
@@ -62,14 +71,14 @@ def forward(self,
adapter_ids=lora_input.adapter_ids,
max_rank=adapter_info.max_rank,
max_seqlen=lora_input.max_seq_len,
+ output=kernel_output,
+ cum=cum,
)
- base_slice = adapter_info.base_slice
- sliced_base = base_output[..., base_slice]
- lora_out = lora_out.reshape(sliced_base.shape)
- sliced_base.add_(lora_out)
- output = base_output
- return output
+ if not base_output.is_contiguous():
+ lora_out = lora_out.reshape(sliced_base.shape)
+ sliced_base.add_(lora_out)
+ return base_output
class TritonLoRABuilder(LoRABuilder):
diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py
index eb38401211..a913ca82fb 100644
--- a/lmdeploy/pytorch/backends/cuda/moe.py
+++ b/lmdeploy/pytorch/backends/cuda/moe.py
@@ -4,9 +4,17 @@
import torch
-from lmdeploy.pytorch.kernels.cuda import fused_moe
+from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
+from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \
+ fused_moe_blocked_fp8
+from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \
+ per_token_quant_int8
+from lmdeploy.pytorch.models.q_modules import QTensor
-from ..moe import FusedMoEBuilder, FusedMoEImpl
+from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl,
+ FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
+ FusedMoEW8A8Impl)
class TritonFusedMoEImpl(FusedMoEImpl):
@@ -74,3 +82,185 @@ def build(top_k: int, num_experts: int, renormalize: bool = False):
return TritonFusedMoEImpl(top_k=top_k,
num_experts=num_experts,
renormalize=renormalize)
+
+
+class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl):
+ """triton fused moe w8a8 implementation."""
+
+ def __init__(self,
+ top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16):
+ self.num_experts = num_experts
+ self.top_k = top_k
+ self.renormalize = renormalize
+ self.out_dtype = out_dtype
+
+ def update_weights(self, gate_up_weights: torch.Tensor,
+ down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
+ down_scale: torch.Tensor):
+ gate_up_weights = gate_up_weights.transpose(1,
+ 2).contiguous().transpose(
+ 1, 2)
+ down_weights = down_weights.transpose(1,
+ 2).contiguous().transpose(1, 2)
+ return gate_up_weights, down_weights, gate_up_scale, down_scale
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return True
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ num_experts = self.num_experts
+ expert_per_rank = (num_experts + world_size - 1) // world_size
+ first_expert = rank * expert_per_rank
+ last_expert = min(first_expert + expert_per_rank, num_experts)
+ return list(range(first_expert, last_expert))
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+
+ if isinstance(hidden_states, torch.Tensor):
+ hidden_states = hidden_states.contiguous()
+ input_quant, input_scale = per_token_quant_int8(
+ hidden_states, 1e-7)
+ else:
+ assert isinstance(hidden_states, QTensor)
+ input_quant, input_scale = (hidden_states.tensor,
+ hidden_states.scale)
+
+ expert_offset = 0
+ num_experts = None
+ if expert_list is not None and len(expert_list) != self.num_experts:
+ expert_offset = expert_list[0]
+ num_experts = self.num_experts
+ return fused_moe_w8a8(input_quant,
+ input_scale,
+ gate_up_weights,
+ gate_up_scale,
+ down_weights,
+ down_scale,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ topk=self.top_k,
+ out_dtype=self.out_dtype,
+ expert_offset=expert_offset,
+ num_experts=num_experts,
+ renormalize=self.renormalize)
+
+
+class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder):
+ """triton fused moe w8a8 builder."""
+
+ @staticmethod
+ def build(top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16):
+ """build from mlp."""
+ return TritonFusedMoEW8A8Impl(top_k=top_k,
+ num_experts=num_experts,
+ renormalize=renormalize,
+ out_dtype=out_dtype)
+
+
+class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
+ """triton fused moe blocked f8 implementation."""
+
+ def __init__(self,
+ top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ block_size: int = 128,
+ out_dtype: torch.dtype = torch.float16):
+ self.num_experts = num_experts
+ self.top_k = top_k
+ self.renormalize = renormalize
+ self.block_size = block_size
+ self.out_dtype = out_dtype
+
+ def update_weights(self, gate_up_weights: torch.Tensor,
+ down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
+ down_scale: torch.Tensor):
+ gate_up_weights = gate_up_weights.transpose(1,
+ 2).contiguous().transpose(
+ 1, 2)
+ down_weights = down_weights.transpose(1,
+ 2).contiguous().transpose(1, 2)
+ return gate_up_weights, down_weights, gate_up_scale, down_scale
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return True
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ num_experts = self.num_experts
+ expert_per_rank = (num_experts + world_size - 1) // world_size
+ first_expert = rank * expert_per_rank
+ last_expert = min(first_expert + expert_per_rank, num_experts)
+ return list(range(first_expert, last_expert))
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+ input_size = hidden_states.shape
+ hidden_states = hidden_states.flatten(0, -2)
+ input_quant, input_scale = quant_fp8(hidden_states,
+ self.block_size,
+ dtype=gate_up_weights.dtype)
+
+ expert_offset = 0
+ num_experts = None
+ if expert_list is not None and len(expert_list) != self.num_experts:
+ expert_offset = expert_list[0]
+ num_experts = self.num_experts
+ output = fused_moe_blocked_fp8(input_quant,
+ input_scale,
+ gate_up_weights,
+ gate_up_scale,
+ down_weights,
+ down_scale,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ topk=self.top_k,
+ out_dtype=hidden_states.dtype,
+ expert_offset=expert_offset,
+ num_experts=num_experts,
+ renormalize=self.renormalize)
+ output = output.unflatten(0, input_size[:-1])
+ return output
+
+
+class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
+ """triton fused moe blocked f8 builder."""
+
+ @staticmethod
+ def build(top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ block_size: int = 128,
+ out_dtype: torch.dtype = torch.float16):
+ """build from mlp."""
+ return TritonFusedMoEBlockedF8Impl(top_k=top_k,
+ num_experts=num_experts,
+ renormalize=renormalize,
+ block_size=block_size,
+ out_dtype=out_dtype)
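
Both new MoE implementations shard experts with the same ceil-division rule in ep_expert_list, so the last rank simply receives the remainder. A small worked example of that arithmetic:

    # Worked example of ep_expert_list above: 64 experts split across 3 ranks.
    def ep_expert_list(num_experts, world_size, rank):
        expert_per_rank = (num_experts + world_size - 1) // world_size
        first_expert = rank * expert_per_rank
        last_expert = min(first_expert + expert_per_rank, num_experts)
        return list(range(first_expert, last_expert))

    assert ep_expert_list(64, 3, 0) == list(range(0, 22))   # 22 experts
    assert ep_expert_list(64, 3, 2) == list(range(44, 64))  # remaining 20 experts
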
diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py
index bfd77a250f..cbc46352a5 100644
--- a/lmdeploy/pytorch/backends/cuda/op_backend.py
+++ b/lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -23,9 +23,12 @@ def get_name() -> str:
@classmethod
def get_layer_impl_builder(cls, layer_type: OpType):
"""get cuda layer builder."""
- if layer_type == OpType.Attention:
+ if layer_type == OpType.PagedAttention:
from .attention import TritonAttentionBuilder
return TritonAttentionBuilder
+ elif layer_type == OpType.FlashAttention:
+ from .flash_attention import TritonFlashAttentionBuilder
+ return TritonFlashAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import TritonApplyRotaryEmbBuilder
return TritonApplyRotaryEmbBuilder
@@ -48,21 +51,20 @@ def get_layer_impl_builder(cls, layer_type: OpType):
from .activation import TritonSiluAndMulBuilder
return TritonSiluAndMulBuilder
elif layer_type == OpType.LinearW4A16:
- try:
- from awq.modules.linear.gemm import awq_ext # noqa: F401
- AWQ_INSTALLED = True
- except Exception:
- AWQ_INSTALLED = False
- if AWQ_INSTALLED:
- from .awq_modules import AwqLinearW4A16Builder
- return AwqLinearW4A16Builder
- else:
- logger.debug(
- f'Op {layer_type} fallback to default implementation.')
- return super().get_layer_impl_builder(layer_type)
+ from .awq_modules import AwqLinearW4A16Builder
+ return AwqLinearW4A16Builder
elif layer_type == OpType.FusedMoE:
from .moe import TritonFusedMoEBuilder
return TritonFusedMoEBuilder
+ elif layer_type == OpType.FusedMoEW8A8:
+ from .moe import TritonFusedMoEW8A8Builder
+ return TritonFusedMoEW8A8Builder
+ elif layer_type == OpType.FusedMoEBlockedF8:
+ from .moe import TritonFusedMoEBlockedF8Builder
+ return TritonFusedMoEBlockedF8Builder
+ elif layer_type == OpType.LinearBlockedF8:
+ from .blockedf8_modules import TritonLinearBlockedF8Builder
+ return TritonLinearBlockedF8Builder
else:
logger.debug(
f'Op {layer_type} fallback to default implementation.')
@@ -142,30 +144,30 @@ def update_step_context(cls, step_context):
medusa_attn_mask=step_context.medusa_attn_mask,
)
- cross_attn_metadata = None
- fill_seqlens = None
- if step_context.cross_attention_states is not None:
- fill_seqlens = torch.zeros_like(q_seqlens)
- for idx, state in enumerate(step_context.cross_attention_states):
- if state is not None:
- fill_seqlens[idx] = state.shape[-2]
+ cross_seqlens = step_context.cross_seqlens
cross_kv_seqlens = step_context.cross_kv_seqlens
- cross_kv_start_loc = None
- cross_kv_flatten_size = None
- if not step_context.is_decoding and cross_kv_seqlens is not None:
- cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens
- cross_kv_flatten_size = cross_kv_seqlens.sum().item()
- cross_attn_metadata = attn_meta_cls(
- step_context.is_decoding,
- step_context.block_offsets,
- q_start_loc=q_start_loc,
- q_seqlens=q_seqlens,
- kv_start_loc=cross_kv_start_loc,
- kv_seqlens=cross_kv_seqlens,
- kv_flatten_size=cross_kv_flatten_size,
- fill_seqlens=fill_seqlens,
- quant_policy=step_context.kv_quant_policy,
- )
+ cross_attn_metadata = None
+ if cross_seqlens is not None:
+ fill_seqlens = cross_seqlens
+ if fill_seqlens.sum().item() == 0:
+ fill_seqlens = None
+ cross_kv_start_loc = None
+ cross_kv_flatten_size = None
+ if not step_context.is_decoding and cross_kv_seqlens is not None:
+ cross_kv_start_loc = cross_kv_seqlens.cumsum(
+ 0) - cross_kv_seqlens
+ cross_kv_flatten_size = cross_kv_seqlens.sum().item()
+ cross_attn_metadata = attn_meta_cls(
+ step_context.is_decoding,
+ step_context.block_offsets,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=cross_kv_start_loc,
+ kv_seqlens=cross_kv_seqlens,
+ kv_flatten_size=cross_kv_flatten_size,
+ fill_seqlens=fill_seqlens,
+ quant_policy=step_context.kv_quant_policy,
+ )
step_context.attn_metadata = attn_metadata
step_context.cross_attn_metadata = cross_attn_metadata
diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py
index 30f729a63f..13d9a47ddf 100644
--- a/lmdeploy/pytorch/backends/cuda/qmodules.py
+++ b/lmdeploy/pytorch/backends/cuda/qmodules.py
@@ -15,42 +15,62 @@
class TritonRMSNormW8A8Impl(RMSNormW8A8Impl):
"""triton RMS norm w8a8 implementation api."""
- def __init__(self, hidden_size: int, eps: float = 1e-6):
+ def __init__(self,
+ hidden_size: int,
+ eps: float = 1e-6,
+ quant_dtype: torch.dtype = torch.int8):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
+ self.quant_dtype = quant_dtype
def forward(self,
x: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor = None):
"""forward."""
- if residual is not None:
- x = x + residual
- residual = x
- hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
- x, weight, self.eps)
- x = QTensor(hidden_states_quant, rms_scale)
if residual is None:
+ (x,
+ rms_scale) = rms_norm_dynamic_quant(x,
+ weight,
+ self.eps,
+ quant_dtype=self.quant_dtype)
+ x = QTensor(x, rms_scale)
return x
- return x, residual
+ else:
+ (x, rms_scale,
+ residual) = rms_norm_dynamic_quant(x,
+ weight,
+ self.eps,
+ residual=residual,
+ quant_dtype=self.quant_dtype)
+ x = QTensor(x, rms_scale)
+ return x, residual
class TritonRMSNormBuilder(RMSNormW8A8Builder):
"""triton RMS norm w8a8 implementation builder."""
@staticmethod
- def build(hidden_size: int, eps: float = 1e-6):
+ def build(hidden_size: int,
+ eps: float = 1e-6,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
- return TritonRMSNormW8A8Impl(hidden_size, eps)
+ return TritonRMSNormW8A8Impl(hidden_size, eps, quant_dtype)
class TritonLinearW8A8Impl(LinearW8A8Impl):
"""triton linear w8a8 implementation."""
- def __init__(self, in_features: int, out_features: int):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ out_dtype: torch.dtype = torch.float16,
+ quant_dtype: torch.dtype = torch.int8):
self.in_features = in_features
self.out_features = out_features
+ self.out_dtype = out_dtype
+ self.quant_dtype = quant_dtype
def forward(self,
x,
@@ -60,8 +80,8 @@ def forward(self,
all_reduce: bool = False):
"""forward."""
if isinstance(x, torch.Tensor):
- x = x.contiguous()
- input_quant, input_scale = per_token_quant_int8(x, 1e-7)
+ input_quant, input_scale = per_token_quant_int8(
+ x, 1e-7, quant_dtype=self.quant_dtype)
else:
assert isinstance(x, QTensor)
input_quant, input_scale = x.tensor, x.scale
@@ -70,7 +90,7 @@ def forward(self,
weight,
input_scale,
scale,
- output_dtype=torch.float16,
+ output_dtype=self.out_dtype,
bias=bias)
if all_reduce:
@@ -85,6 +105,10 @@ class TritonLinearW8A8Builder(LinearW8A8Builder):
def build(in_features: int,
out_features: int,
bias: bool = True,
- dtype: torch.dtype = None):
+ dtype: torch.dtype = None,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
- return TritonLinearW8A8Impl(in_features, out_features)
+ return TritonLinearW8A8Impl(in_features,
+ out_features,
+ dtype,
+ quant_dtype=quant_dtype)
diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
index f9664f13ff..e3c5dc4d5e 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
@@ -33,10 +33,17 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
dlinfer.graph.config.enable_graph_mode = True
self.patch_kernels_custom_op()
self.patch_kvcache_static_shape()
- self.model = torch.compile(self.model,
- fullgraph=True,
- dynamic=True,
- backend='atbgraph')
+ if hasattr(self.model, 'language_model'):
+ self.model.language_model = torch.compile(
+ self.model.language_model,
+ fullgraph=True,
+ dynamic=True,
+ backend='atbgraph')
+ else:
+ self.model = torch.compile(self.model,
+ fullgraph=True,
+ dynamic=True,
+ backend='atbgraph')
def check_enable_graph(self):
"""check enable graph."""
diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index b6f544510b..588558f0d5 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+import itertools
+import os
+import re
+from pathlib import Path
+from typing import Dict, Tuple
import torch
@@ -11,6 +15,71 @@
logger = get_logger('lmdeploy')
+class AscendKVQuantMeta:
+ has_set_value: bool = False
+ quant_meta: Dict = {}
+
+ @classmethod
+ def set_value(cls, device: str, dtype: torch.dtype, record_file: str,
+ total_layers: int):
+ with open(record_file, 'r') as file:
+ data = file.read()
+ scale_offset_pairs = re.findall(
+ r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data)
+ scale_offset_pairs = [(float(scale), float(offset))
+ for scale, offset in scale_offset_pairs]
+ k_scales, v_scales, kv_scales = [], [], []
+ k_zeros, v_zeros, kv_zeros = [], [], []
+ if len(scale_offset_pairs) == total_layers:
+ for scale, offset in scale_offset_pairs:
+ k_scales.append(
+ torch.tensor([scale], device=device, dtype=dtype))
+ v_scales.append(
+ torch.tensor([scale], device=device, dtype=dtype))
+ kv_scales.append(
+ torch.tensor([scale, scale], device=device, dtype=dtype))
+ k_zeros.append(
+ torch.tensor([offset], device=device, dtype=dtype))
+ v_zeros.append(
+ torch.tensor([offset], device=device, dtype=dtype))
+ kv_zeros.append(
+ torch.tensor([offset, offset], device=device, dtype=dtype))
+ elif len(scale_offset_pairs) == total_layers * 2:
+ for i in range(total_layers):
+ scale_k, offset_k = scale_offset_pairs[2 * i]
+ scale_v, offset_v = scale_offset_pairs[2 * i + 1]
+ k_scales.append(
+ torch.tensor([scale_k], device=device, dtype=dtype))
+ v_scales.append(
+ torch.tensor([scale_v], device=device, dtype=dtype))
+ kv_scales.append(
+ torch.tensor([scale_k, scale_v],
+ device=device,
+ dtype=dtype))
+ k_zeros.append(
+ torch.tensor([offset_k], device=device, dtype=dtype))
+ v_zeros.append(
+ torch.tensor([offset_v], device=device, dtype=dtype))
+ kv_zeros.append(
+ torch.tensor([offset_k, offset_v],
+ device=device,
+ dtype=dtype))
+ else:
+ raise ValueError(
+ f'num of scale_offset_pairs({len(scale_offset_pairs)}) '
+ f'must match num of total_layers({total_layers})')
+
+ cls.quant_meta.update({
+ 'k_scales': itertools.cycle(k_scales),
+ 'k_zeros': itertools.cycle(k_zeros),
+ 'v_scales': itertools.cycle(v_scales),
+ 'v_zeros': itertools.cycle(v_zeros),
+ 'kv_scales': itertools.cycle(kv_scales),
+ 'kv_zeros': itertools.cycle(kv_zeros)
+ })
+ cls.has_set_value = True
+
+
class AscendOpsBackend(DlinferOpsBackend):
"""ascend layer backend."""
enable_graph = False
@@ -164,6 +233,21 @@ def get_total_slots():
.repeat_interleave(step_context.q_seqlens, 0)
kv_seqlens = kv_seqlens_cpu
+ if not cls.enable_graph and step_context.kv_quant_policy == 8:
+ record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
+ assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'
+ path = Path(record_file)
+ is_path = path.is_absolute() or path.is_relative_to('/')
+ exists = path.exists()
+ if not (is_path and exists):
+ raise ValueError(
+ 'please specify valid ASCEND_QUANT_RECORD_FILE')
+ if not AscendKVQuantMeta.has_set_value:
+ total_layers = len(step_context.kv_caches)
+ AscendKVQuantMeta.set_value(step_context.block_offsets.device,
+ step_context.model_config.dtype,
+ record_file, total_layers)
+
attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
@@ -177,6 +261,8 @@ def get_total_slots():
is_unpaged_prefill=is_unpaged_prefill,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
+ quant_policy=step_context.kv_quant_policy,
+ quant_meta=AscendKVQuantMeta.quant_meta,
)
step_context.attn_metadata = attn_metadata
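
The ASCEND_QUANT_RECORD_FILE parsed above is expected to contain `scale: <float> offset: <int>` pairs, either one pair per layer (shared by k and v) or two pairs per layer (k then v). A minimal sketch of the accepted format; the file actually emitted by the Ascend quantization tooling may differ:

    # Sketch of the record format matched by the regex in AscendKVQuantMeta.set_value.
    import re

    data = 'scale: 0.0123 offset: -3\nscale: 0.0150 offset: 2\n'
    pairs = re.findall(r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data)
    assert pairs == [('0.0123', '-3'), ('0.0150', '2')]
    # len(pairs) == total_layers      -> one shared k/v pair per layer
    # len(pairs) == total_layers * 2  -> separate k and v pairs per layer
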
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 0d666c9130..6b03403c84 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
-from typing import Optional, Sequence
+from typing import Dict, Optional, Sequence
from torch import Tensor
@@ -15,6 +15,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
is_unpaged_prefill: Optional[bool] = None
max_q_seq_len: int = 1
max_kv_seq_len: int = 1
+ quant_meta: Dict = None
class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
@@ -30,8 +31,10 @@ def __init__(
alibi: bool = None,
sliding_window: int = None,
logit_softcapping: float = None,
+ causal: bool = True,
**kwargs,
):
+ assert causal
super().__init__(
num_heads,
head_size,
@@ -41,6 +44,7 @@ def __init__(
alibi,
sliding_window,
logit_softcapping,
+ causal=causal,
**kwargs,
)
@@ -74,10 +78,37 @@ def forward(
is_unpaged_prefill = attn_metadata.is_unpaged_prefill
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
+ quant_bits = attn_metadata.quant_policy
+ if attn_metadata.quant_meta is not None:
+ k_scales_zeros = [
+ next(attn_metadata.quant_meta['k_scales']),
+ next(attn_metadata.quant_meta['k_zeros'])
+ ] if 'k_scales' in attn_metadata.quant_meta else []
+ v_scales_zeros = [
+ next(attn_metadata.quant_meta['v_scales']),
+ next(attn_metadata.quant_meta['v_zeros'])
+ ] if 'v_scales' in attn_metadata.quant_meta else []
+ kv_scales = next(
+ attn_metadata.quant_meta['kv_scales']
+ ) if 'kv_scales' in attn_metadata.quant_meta else None
+ kv_zeros = next(
+ attn_metadata.quant_meta['kv_zeros']
+ ) if 'kv_zeros' in attn_metadata.quant_meta else None
+ else:
+ k_scales_zeros = []
+ v_scales_zeros = []
+ kv_scales = None
+ kv_zeros = None
# fill kv cache
- k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
- kv_start_indices)
+ k_cache, v_cache = self.fill_kv_cache(key,
+ value,
+ k_cache,
+ v_cache,
+ kv_start_indices,
+ k_scales_zeros=k_scales_zeros,
+ v_scales_zeros=v_scales_zeros,
+ quant_bits=quant_bits)
if inplace:
attn_output = query[..., :self.v_head_size]
@@ -103,6 +134,9 @@ def forward(
block_size=block_size,
attn_mask=attn_mask,
is_unpaged_prefill=is_unpaged_prefill,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
return attn_output
@@ -121,6 +155,7 @@ def build(
alibi_scale: float = None,
sliding_window: int = None,
logical_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> DlinferAttentionImpl:
"""build."""
@@ -132,4 +167,5 @@ def build(
alibi_scale=alibi_scale,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
+ causal=causal,
**kwargs)
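
Because AscendKVQuantMeta stores itertools.cycle iterators, each attention layer's forward simply calls next() and receives that layer's scales/zeros in order, wrapping around on the next pass through the layers. A tiny illustration of that consumption pattern (real entries are per-layer tensors):

    # Minimal illustration of the cycle-based lookup used by quant_meta above.
    import itertools

    k_scales = itertools.cycle([0.01, 0.02])  # layer 0, layer 1, layer 0, ...
    assert [next(k_scales) for _ in range(4)] == [0.01, 0.02, 0.01, 0.02]
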
diff --git a/lmdeploy/pytorch/backends/dlinfer/flash_attention.py b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py
new file mode 100644
index 0000000000..d0d9ddbb26
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import Tensor
+
+from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl
+
+
+class DlinferFlashAttentionImpl(FlashAttentionImpl):
+ """dlinfer flash attention implementation."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ ):
+ if scale is None:
+ scale = 1.0 / (head_dim**0.5)
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+ if v_head_dim is None:
+ v_head_dim = head_dim
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.scale = scale
+ self.num_kv_heads = num_kv_heads
+ self.v_head_dim = v_head_dim
+ self.causal = causal
+ self.sliding_window = sliding_window
+ self.logical_softcapping = logical_softcapping
+ from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd
+ self.flash_attention_fwd = flash_attention_fwd
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None):
+ """forward."""
+ q_shape = query.shape
+ o_shape = q_shape[:-1] + (self.v_head_dim, )
+ out = query.new_empty(o_shape)
+ self.flash_attention_fwd(
+ query,
+ key,
+ value,
+ out,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_q_seqlen=max_q_seqlen,
+ window_size=self.sliding_window,
+ sm_scale=self.scale,
+ logit_softcapping=self.logical_softcapping,
+ causal=self.causal,
+ )
+ return out
+
+
+class DlinferFlashAttentionBuilder(FlashAttentionBuilder):
+ """dlinfer attention builder."""
+
+ @staticmethod
+ def build(
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ **kwargs,
+ ) -> FlashAttentionImpl:
+ """build."""
+ return DlinferFlashAttentionImpl(
+ num_heads=num_heads,
+ head_dim=head_dim,
+ scale=scale,
+ num_kv_heads=num_kv_heads,
+ v_head_dim=v_head_dim,
+ causal=causal,
+ sliding_window=sliding_window,
+ logical_softcapping=logical_softcapping,
+ )
diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py
index 6ada730fbe..ff986c5765 100644
--- a/lmdeploy/pytorch/backends/dlinfer/moe.py
+++ b/lmdeploy/pytorch/backends/dlinfer/moe.py
@@ -47,8 +47,8 @@ def forward(self,
down_weights: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
- return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights,
- gate_up_weights, down_weights)
+ return fused_moe(hidden_states, gate_up_weights, down_weights,
+ topk_weights, topk_ids, self.top_k, self.renormalize)
class DlinferFusedMoEBuilder(FusedMoEBuilder):
diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
index 52a8830595..a0f04f34b1 100644
--- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -22,9 +22,12 @@ def get_name() -> str:
@classmethod
def get_layer_impl_builder(cls, layer_type: OpType):
"""get dlinfer layer builder."""
- if layer_type == OpType.Attention:
+ if layer_type == OpType.PagedAttention:
from .attention import DlinferAttentionBuilder
return DlinferAttentionBuilder
+ elif layer_type == OpType.FlashAttention:
+ from .flash_attention import DlinferFlashAttentionBuilder
+ return DlinferFlashAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
return DlinferApplyRotaryEmbBuilder
diff --git a/lmdeploy/pytorch/backends/flash_attention.py b/lmdeploy/pytorch/backends/flash_attention.py
new file mode 100644
index 0000000000..bed3af8d68
--- /dev/null
+++ b/lmdeploy/pytorch/backends/flash_attention.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+
+from torch import Tensor
+
+
+class FlashAttentionImpl(ABC):
+ """FlashAttention implementation."""
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None):
+ """forward."""
+ raise NotImplementedError
+
+
+class FlashAttentionBuilder(ABC):
+ """FlashAttention implementation builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ **kwargs,
+ ) -> FlashAttentionImpl:
+ """build."""
+ raise NotImplementedError
diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py
index 9ab66b26a2..9347995e0b 100644
--- a/lmdeploy/pytorch/backends/graph_runner.py
+++ b/lmdeploy/pytorch/backends/graph_runner.py
@@ -46,3 +46,26 @@ def prepare_inputs_for_generation(
inputs_embeds,
context,
)
+
+ def update_model_metas(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: torch.Tensor = None,
+ context: StepContext = None,
+ ):
+ """prepare inputs."""
+ if hasattr(self.model, 'update_model_metas'):
+ return self.model.update_model_metas(
+ past_key_values,
+ inputs_embeds,
+ context,
+ )
+
+ return None
+
+ def get_input_processor(self):
+ """get input processor."""
+ if hasattr(self.model, 'get_input_processor'):
+ return self.model.get_input_processor()
+ else:
+ return None
diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py
index 8e7977625e..4501e52c0b 100644
--- a/lmdeploy/pytorch/backends/moe.py
+++ b/lmdeploy/pytorch/backends/moe.py
@@ -60,3 +60,93 @@ class FusedMoEBuilder(ABC):
def build(top_k: int, num_experts: int, renormalize: bool = False):
"""build from mlp."""
raise NotImplementedError
+
+
+class FusedMoEW8A8Impl(ABC):
+ """fused moe w8a8 implementation."""
+
+ def update_weights(self, gate_up_weights: torch.Tensor,
+ down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
+ down_scale: torch.Tensor):
+ """update weights."""
+ return gate_up_weights, down_weights, gate_up_scale, down_scale
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return False
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ raise NotImplementedError('Not Implemented.')
+
+ @abstractmethod
+ def forward(self,
+ hidden_states: torch.Tensor,
+ input_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+ raise NotImplementedError
+
+
+class FusedMoEW8A8Builder(ABC):
+ """fused moe w8a8 builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16):
+ """build from mlp."""
+ raise NotImplementedError
+
+
+class FusedMoEBlockedF8Impl(ABC):
+ """fused moe blocked f8 implementation."""
+
+ def update_weights(self, gate_up_weights: torch.Tensor,
+ down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
+ down_scale: torch.Tensor):
+ """update weights."""
+ return gate_up_weights, down_weights, gate_up_scale, down_scale
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return False
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ raise NotImplementedError('Not Implemented.')
+
+ @abstractmethod
+ def forward(self,
+ hidden_states: torch.Tensor,
+ input_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+ raise NotImplementedError
+
+
+class FusedMoEBlockedF8Builder(ABC):
+ """fused moe blocked f8 builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16):
+ """build from mlp."""
+ raise NotImplementedError
diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py
index a61941b37d..e877a4ca6b 100644
--- a/lmdeploy/pytorch/backends/qmodules.py
+++ b/lmdeploy/pytorch/backends/qmodules.py
@@ -37,7 +37,9 @@ class RMSNormW8A8Builder(ABC):
@staticmethod
@abstractmethod
- def build(hidden_size: int, eps: float = 1e-6):
+ def build(hidden_size: int,
+ eps: float = 1e-6,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
raise NotImplementedError
@@ -71,6 +73,7 @@ class LinearW8A8Builder(ABC):
def build(in_features: int,
out_features: int,
bias: bool = True,
- dtype: torch.dtype = None):
+ dtype: torch.dtype = None,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
raise NotImplementedError
diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 7d72438224..bc95a32be6 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -1,277 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from logging import Logger
-from typing import List
-
-from lmdeploy.utils import get_logger
-
-
-def _handle_exception(e: Exception,
- mod_name: str,
- logger: Logger,
- message: str = None):
- red_color = '\033[31m'
- reset_color = '\033[0m'
- if message is None:
- message = 'Please ensure it has been installed correctly.'
- logger.debug('Exception', exc_info=1)
- logger.error(f'{type(e).__name__}: {e}')
- logger.error(f'{red_color}'
- f'<{mod_name}> test failed!\n'
- f'{message}'
- f'{reset_color}')
- exit(1)
+from .base import BaseChecker # noqa: F401
def check_env_deeplink(device_type: str):
"""check Deeplink environment."""
- try_import_deeplink(device_type)
+ from .deeplink import DeeplinkChecker
+ checker = DeeplinkChecker(device_type)
+ checker.handle()
def try_import_deeplink(device_type: str):
- """import dlinfer if specific device_type is set."""
- deeplink_device_type_list = [
- 'ascend',
- 'npu',
- 'maca',
- ]
- if device_type in deeplink_device_type_list:
- logger = get_logger('lmdeploy')
- try:
- import dlinfer.framework.lmdeploy_ext # noqa: F401
- except Exception as e:
- _handle_exception(e, 'PyTorch', logger)
-
-
-def check_env_torch():
- """check PyTorch environment."""
- logger = get_logger('lmdeploy')
-
- try:
- logger.debug('Checking environment.')
- import torch
-
- a = torch.tensor([1, 2], device='cuda')
- b = a.new_tensor([3, 4], device='cuda')
- c = a + b
- torch.testing.assert_close(c, a.new_tensor([4, 6]))
- except Exception as e:
- _handle_exception(e, 'PyTorch', logger)
-
-
-MAX_TRITON_VERSION = '3.0.0'
-
-
-def check_env_triton(device: str):
- """check OpenAI Triton environment."""
- from packaging import version
- logger = get_logger('lmdeploy')
-
- msg = (
-        'Please ensure that your device is functioning properly with <Triton>.\n' # noqa: E501
- 'You can verify your environment by running '
- '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
- try:
- logger.debug('Checking environment.')
- import torch
- import triton
- triton_version = version.parse(triton.__version__)
- if triton_version > version.parse(MAX_TRITON_VERSION):
- logger.warning(
- f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')
-
- from .triton_custom_add import custom_add
- a = torch.tensor([1, 2], device='cuda')
- b = a.new_tensor([3, 4], device='cuda')
- c = custom_add(a, b)
- torch.testing.assert_close(c, a + b)
- except RuntimeError as e:
- ptxas_error = 'device kernel image is invalid'
- if len(e.args) > 0 and ptxas_error in e.args[0]:
- msg = (
- 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501
- 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501
- ' or reinstall the driver.')
- _handle_exception(e, 'Triton', logger, msg)
- except Exception as e:
- _handle_exception(e, 'Triton', logger, msg)
-
- if device == 'cuda':
- device_cap = torch.cuda.get_device_capability()
- TRITON_VER_231 = version.parse('2.3.1')
-
- if device_cap[0] <= 7:
- if triton_version <= TRITON_VER_231:
- err = RuntimeError(
- 'Attention triton kernel does not fully support '
- 'triton<3.0.0 on device with capability<8. '
- 'Please upgrade your triton version.')
- _handle_exception(err, 'Triton', logger)
-
-
-def check_env(device_type: str):
- """check all environment."""
- logger = get_logger('lmdeploy')
- logger.info('Checking environment for PyTorch Engine.')
+ """check Deeplink environment."""
check_env_deeplink(device_type)
- check_env_torch()
- if device_type == 'cuda':
- check_env_triton('cuda')
-
-
-MIN_TRANSFORMERS_VERSION = '4.33.0'
-MAX_TRANSFORMERS_VERSION = '4.44.1'
-
-
-def check_awq(hf_config, device_type):
- """check awq support."""
- logger = get_logger('lmdeploy')
- if device_type == 'cuda':
- quantization_config = getattr(hf_config, 'quantization_config', dict())
- quant_method = quantization_config.get('quant_method', None)
- if quant_method != 'awq':
- return
- try:
- import awq # noqa
- except Exception as e:
- _handle_exception(e, 'autoawq', logger)
-
- try:
- import awq_ext # noqa
- except Exception:
- logger.debug('Exception:', exc_info=1)
- logger.warning('Failed to import `awq_ext`. '
- 'Try reinstall it from source: '
- 'https://github.com/casper-hansen/AutoAWQ_kernels')
-
-
-def check_transformers_version(model_path: str,
- trust_remote_code: bool = True,
- dtype: str = 'auto',
- device_type: str = 'cuda'):
- """check transformers version."""
- from packaging import version
- logger = get_logger('lmdeploy')
-
- def __check_transformers_version():
- """check transformers version."""
- logger.debug('Checking version.')
- trans_version = None
- try:
- import transformers
- trans_version = version.parse(transformers.__version__)
- min_version = version.parse(MIN_TRANSFORMERS_VERSION)
- max_version = version.parse(MAX_TRANSFORMERS_VERSION)
- if trans_version < min_version or trans_version > max_version:
- logger.warning('LMDeploy requires transformers version: '
- f'[{MIN_TRANSFORMERS_VERSION} ~ '
- f'{MAX_TRANSFORMERS_VERSION}], '
- 'but found version: '
- f'{transformers.__version__}')
- except Exception as e:
- _handle_exception(e, 'transformers', logger)
- return transformers, trans_version
-
- def __check_config(trans_version):
- """check config."""
- logger.debug('Checking AutoConfig.from_pretrained.')
- try:
- from transformers import AutoConfig
- config = AutoConfig.from_pretrained(
- model_path, trust_remote_code=trust_remote_code)
- except Exception as e:
- message = (
- f'Load model config with transformers=={trans_version}'
- ' failed. '
- 'Please make sure model can be loaded with transformers API.')
- _handle_exception(e, 'transformers', logger, message=message)
- return config
-
- def __check_model_transformers_version(config, trans_version):
- """check model transformers version."""
- logger.debug('Checking required transformers version.')
- try:
- model_trans_version = getattr(config, 'transformers_version', None)
- if model_trans_version is not None:
- model_trans_version = version.parse(model_trans_version)
- assert trans_version >= model_trans_version, \
- 'Version mismatch.'
- except Exception as e:
- message = (f'model `{model_path}` requires '
- f'transformers version {model_trans_version} '
- f'but transformers {trans_version} is installed.')
- _handle_exception(e, 'transformers', logger, message=message)
-
- def __check_model_dtype_support(config, device_type):
- """Checking model dtype support."""
- logger.debug('Checking dtype support.')
-
- import torch
-
- from lmdeploy.pytorch.config import ModelConfig
- from lmdeploy.utils import is_bf16_supported
-
- try:
- model_config = ModelConfig.from_hf_config(config,
- model_path=model_path,
- dtype=dtype)
- if model_config.dtype == torch.bfloat16:
- assert is_bf16_supported(device_type), (
- 'bf16 is not supported on your device')
- except AssertionError as e:
- message = (
- f'Your device does not support `{model_config.dtype}`. '
- 'You can set `dtype` to float16 in PyTorchEngineConfig or '
- '`--dtype float16` to api_server.\n'
- 'Note that this might have negative effect!')
- _handle_exception(e, 'Model', logger, message=message)
- except Exception as e:
- message = (f'Checking failed with error {e}',
- 'Please send issue to LMDeploy with error logs.')
- _handle_exception(e, 'Model', logger, message=message)
-
- return model_config
-
- _, trans_version = __check_transformers_version()
- config = __check_config(trans_version)
- __check_model_transformers_version(config, trans_version)
- __check_model_dtype_support(config, device_type)
- check_awq(config, device_type)
-
-
-def check_model(model_path: str,
- trust_remote_code: bool = True,
- dtype: str = 'auto',
- device_type: str = 'cuda'):
- """check model requirements."""
- logger = get_logger('lmdeploy')
- logger.info('Checking model.')
- check_transformers_version(model_path, trust_remote_code, dtype,
- device_type)
-
-
-def check_adapter(path: str):
- """check adapter."""
- logger = get_logger('lmdeploy')
- logger.debug(f'Checking : {path}.')
-
- try:
- from peft import PeftConfig
- PeftConfig.from_pretrained(path)
- except Exception as e:
- message = ('Please make sure the adapter can be loaded with '
- '`peft.PeftConfig.from_pretrained`\n')
- err_msg = '' if len(e.args) == 0 else e.args[0]
- if 'got an unexpected keyword argument' in err_msg:
- message += ('Or try remove all unexpected keywords '
- 'in `adapter_config.json`.')
- _handle_exception(e, 'Model', logger, message=message)
-
-
-def check_adapters(adapter_paths: List[str]):
- """check adapters."""
- if len(adapter_paths) <= 0:
- return
- logger = get_logger('lmdeploy')
- logger.info('Checking adapters.')
- for path in adapter_paths:
- check_adapter(path)
diff --git a/lmdeploy/pytorch/check_env/adapter.py b/lmdeploy/pytorch/check_env/adapter.py
new file mode 100644
index 0000000000..bcaf5fd0e3
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/adapter.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+
+class AdapterChecker(BaseChecker):
+ """check adapter is available."""
+
+ def __init__(self, adapter_path: str, logger=None):
+ super().__init__(logger)
+ self.adapter_path = adapter_path
+
+ def check(self):
+ """check."""
+ path = self.adapter_path
+
+ try:
+ import peft # noqa: F401
+ except Exception as e:
+ self.log_and_exit(e, 'Adapter', message='Failed to import peft.')
+
+ try:
+ from peft import PeftConfig
+ PeftConfig.from_pretrained(path)
+ except Exception as e:
+ message = ('Please make sure the adapter can be loaded with '
+ '`peft.PeftConfig.from_pretrained`\n')
+ err_msg = '' if len(e.args) == 0 else e.args[0]
+ if 'got an unexpected keyword argument' in err_msg:
+                message += ('Or try removing all unexpected keywords '
+ 'in `adapter_config.json`.')
+ self.log_and_exit(e, 'Adapter', message=message)
diff --git a/lmdeploy/pytorch/check_env/base.py b/lmdeploy/pytorch/check_env/base.py
new file mode 100644
index 0000000000..ed5e5a600f
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/base.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from logging import Logger
+from typing import List
+
+from lmdeploy.utils import get_logger
+
+RED_COLOR = '\033[31m'
+RESET_COLOR = '\033[0m'
+
+
+def _red_text(text: str):
+ """red text."""
+ return f'{RED_COLOR}{text}{RESET_COLOR}'
+
+
+class BaseChecker:
+ """base checker."""
+
+ def __init__(self, logger: Logger = None):
+ if logger is None:
+ logger = get_logger('lmdeploy')
+ self.logger = logger
+ self._is_passed = False
+ self._required_checker: List[BaseChecker] = list()
+
+ def get_logger(self):
+ """get logger."""
+ return self.logger
+
+ def register_required_checker(self, checker: 'BaseChecker'):
+ """register_required."""
+ self._required_checker.append(checker)
+
+ def handle(self):
+ """handle check."""
+ is_passed = getattr(self, '_is_passed', False)
+ if not is_passed:
+ checker_name = type(self).__name__
+ self.logger.debug(f'Checking <{checker_name}>:')
+ for checker in self._required_checker:
+ checker.handle()
+ self.check()
+            self._is_passed = True
+
+ def log_and_exit(self,
+ e: Exception = None,
+ mod_name: str = None,
+ message: str = None):
+ logger = self.logger
+ if mod_name is None:
+ mod_name = type(self).__name__
+ if message is None:
+ message = 'Please check your environment.'
+ logger.debug('Exception', exc_info=1)
+ if e is not None:
+ logger.error(f'{type(e).__name__}: {e}')
+ logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}')
+ exit(1)
+
+ def check(self):
+ """check."""
+ raise NotImplementedError('check not implemented.')
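
Checkers are meant to be composed: register_required_checker() chains prerequisites, and handle() runs them before the checker's own check(), exiting on the first failure. A hypothetical composition sketch (the actual wiring lives in the engine code outside this diff):

    # Hypothetical sketch: a model check that requires the PyTorch check to pass first.
    from lmdeploy.pytorch.check_env.model import ModelChecker
    from lmdeploy.pytorch.check_env.torch import TorchChecker

    checker = ModelChecker(model_path='/path/to/model',
                           trust_remote_code=True,
                           dtype='auto',
                           device_type='cuda')
    checker.register_required_checker(TorchChecker())
    checker.handle()  # TorchChecker.check() runs before ModelChecker.check()
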
diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py
new file mode 100644
index 0000000000..74ab5a7b87
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/deeplink.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+deeplink_device_type_list = [
+ 'ascend',
+ 'npu',
+ 'maca',
+]
+
+
+class DeeplinkChecker(BaseChecker):
+ """check pytorch is available."""
+
+ def __init__(self, device_type: str, logger=None) -> None:
+ super().__init__(logger=logger)
+ self.device_type = device_type
+
+ def check(self):
+ """check."""
+ device_type = self.device_type
+ if device_type in deeplink_device_type_list:
+ try:
+ import dlinfer.framework.lmdeploy_ext # noqa: F401
+ except Exception as e:
+ self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.')
diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py
new file mode 100644
index 0000000000..79d8d26e3c
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/model.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+
+class ModelChecker(BaseChecker):
+ """check model is available."""
+
+ def __init__(self,
+ model_path: str,
+ trust_remote_code: bool,
+ dtype: str,
+ device_type: str,
+ logger=None) -> None:
+ super().__init__(logger=logger)
+ self.model_path = model_path
+ self.trust_remote_code = trust_remote_code
+ self.device_type = device_type
+ self.dtype = dtype
+
+ def check_config(self, trans_version):
+ """check config."""
+ model_path = self.model_path
+ trust_remote_code = self.trust_remote_code
+ try:
+ from transformers import AutoConfig
+ config = AutoConfig.from_pretrained(
+ model_path, trust_remote_code=trust_remote_code)
+ except Exception as e:
+ message = (
+ f'Load model config with transformers=={trans_version}'
+ ' failed. '
+ 'Please make sure model can be loaded with transformers API.')
+ self.log_and_exit(e, 'transformers', message=message)
+ return config
+
+ def check_trans_version(self, config, trans_version):
+ """check transformers version."""
+ model_path = self.model_path
+ try:
+ model_trans_version = getattr(config, 'transformers_version', None)
+ if model_trans_version is not None:
+ model_trans_version = version.parse(model_trans_version)
+ assert trans_version >= model_trans_version, (
+ 'Version mismatch.')
+ except Exception as e:
+ message = (f'model `{model_path}` requires '
+ f'transformers version {model_trans_version} '
+ f'but transformers {trans_version} is installed.')
+ self.log_and_exit(e, 'transformers', message=message)
+
+ def check_dtype(self, config):
+ """check dtype."""
+ logger = self.get_logger()
+ model_path = self.model_path
+ device_type = self.device_type
+ dtype = self.dtype
+ try:
+ import torch
+
+ from lmdeploy.pytorch.config import ModelConfig
+ from lmdeploy.utils import is_bf16_supported
+ model_config = ModelConfig.from_hf_config(config,
+ model_path=model_path,
+ dtype=dtype)
+ if model_config.dtype == torch.bfloat16:
+ if not is_bf16_supported(device_type):
+ logger.warning('Device does not support bfloat16.')
+ except Exception as e:
+            message = (f'Checking failed with error {e}. '
+                       'Please send issue to LMDeploy with error logs.')
+ self.log_and_exit(e, 'Model', message=message)
+
+ def check(self):
+ """check."""
+ import transformers
+ trans_version = version.parse(transformers.__version__)
+
+ # config
+ config = self.check_config(trans_version)
+
+ # transformers version
+ self.check_trans_version(config, trans_version)
+
+ # dtype check
+ self.check_dtype(config)
diff --git a/lmdeploy/pytorch/check_env/torch.py b/lmdeploy/pytorch/check_env/torch.py
new file mode 100644
index 0000000000..14b24e04a0
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/torch.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+
+class TorchChecker(BaseChecker):
+ """check pytorch is available."""
+
+ def __init__(self, device: str = 'cuda', logger=None) -> None:
+ super().__init__(logger=logger)
+ self.device = device
+
+ def check(self):
+ """check."""
+ try:
+ import torch
+ a = torch.tensor([1, 2], device=self.device)
+ b = a.new_tensor([3, 4], device=self.device)
+ c = a + b
+ torch.testing.assert_close(c, a.new_tensor([4, 6]))
+ except Exception as e:
+ self.log_and_exit(e, 'PyTorch', 'PyTorch is not available.')
diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py
new file mode 100644
index 0000000000..9d97cd6dca
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/transformers.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+MIN_TRANSFORMERS_VERSION = '4.33.0'
+MAX_TRANSFORMERS_VERSION = '4.46.1'
+
+
+class TransformersChecker(BaseChecker):
+ """check transformers is available."""
+
+ def check(self):
+ """check."""
+ import transformers
+ logger = self.get_logger()
+ try:
+ trans_version = version.parse(transformers.__version__)
+ min_version = version.parse(MIN_TRANSFORMERS_VERSION)
+ max_version = version.parse(MAX_TRANSFORMERS_VERSION)
+ if trans_version < min_version or trans_version > max_version:
+ logger.warning('LMDeploy requires transformers version: '
+ f'[{MIN_TRANSFORMERS_VERSION} ~ '
+ f'{MAX_TRANSFORMERS_VERSION}], '
+ 'but found version: '
+ f'{transformers.__version__}')
+ except Exception as e:
+ self.log_and_exit(e, 'transformers',
+ 'transformers is not available.')
diff --git a/lmdeploy/pytorch/check_env/triton.py b/lmdeploy/pytorch/check_env/triton.py
new file mode 100644
index 0000000000..4cc58c5492
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/triton.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+MAX_TRITON_VERSION = '3.1.0'
+MIN_TRITON_VERSION = '3.0.0'
+
+
+class TritonChecker(BaseChecker):
+ """check triton is available."""
+
+ def check_version(self):
+ """check version."""
+ logger = self.get_logger()
+
+ # version check
+ import triton
+ max_version = version.parse(MAX_TRITON_VERSION)
+ min_version = version.parse(MIN_TRITON_VERSION)
+ triton_version = version.parse(triton.__version__)
+
+ if triton_version > max_version:
+ logger.warning('PytorchEngine has not been tested on '
+ f'triton>{MAX_TRITON_VERSION}.')
+ if triton_version < min_version:
+ msg = (f'triton>={MIN_TRITON_VERSION} is required. '
+ f'Found triton=={triton_version}')
+ self.log_and_exit(mod_name='Triton', message=msg)
+
+ def check(self):
+ """check."""
+ logger = self.get_logger()
+
+ msg = (
+            'Please ensure that your device is functioning properly with <Triton>.\n'  # noqa: E501
+ 'You can verify your environment by running '
+ '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
+ try:
+ logger.debug('Checking environment.')
+ import torch
+
+ from .triton_custom_add import custom_add
+ a = torch.tensor([1, 2], device='cuda')
+ b = a.new_tensor([3, 4], device='cuda')
+ c = custom_add(a, b)
+ torch.testing.assert_close(c, a + b)
+ except RuntimeError as e:
+ ptxas_error = 'device kernel image is invalid'
+ if len(e.args) > 0 and ptxas_error in e.args[0]:
+ msg = (
+                    'This error might be caused by a mismatch between the NVIDIA driver and the nvcc compiler. \n'  # noqa: E501
+ 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501
+ ' or reinstall the driver.')
+ self.log_and_exit(e, 'Triton', msg)
+ except Exception as e:
+ self.log_and_exit(e, 'Triton', msg)
+
+ # version check
+ self.check_version()
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index a9381890ee..b2cdc304b7 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -26,6 +26,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
return config
torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
+ # deal with case when torch_dtype is not string but torch.dtype
+ if isinstance(torch_dtype, torch.dtype):
+ torch_dtype = str(torch_dtype).split('.')[1]
+
if torch_dtype is None:
_dtype = 'float16' if dtype == 'auto' else dtype
logger.warning('Model config does not have `torch_dtype`,'
@@ -37,8 +41,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
# change to user specified data type if it is not 'auto'
if dtype == 'auto':
torch_dtype = torch_dtype if torch_dtype in [
- torch.float16, torch.bfloat16
- ] else torch.float16
+ 'float16', 'bfloat16'
+ ] else 'float16'
else:
torch_dtype = dtype
config.dtype = eval(f'torch.{torch_dtype}')
@@ -77,6 +81,7 @@ class CacheConfig:
max_prefill_token_num: int = 4096
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
+ device_type: str = 'cuda'
def __post_init__(self):
"""post init."""
@@ -103,7 +108,6 @@ class ModelConfig:
v_head_dim: int = None
sliding_window: int = -1
dtype: torch.dtype = torch.float16
- multi_query_attention: bool = False
vocab_size: int = 40000
hf_config: Any = None
cogvlm_style: bool = False
@@ -120,7 +124,8 @@ def get_head_size(self):
def from_pretrained(cls,
pretrained_model_name_or_path: str,
trust_remote_code: bool = True,
- dtype: str = 'auto'):
+ dtype: str = 'auto',
+ tp: int = 1):
"""Instantiate one of the configuration classes of the library from a
pretrained model configuration.
@@ -150,17 +155,21 @@ def from_pretrained(cls,
setattr(hf_config, 'architectures', ['MedusaModel'])
return cls.from_hf_config(hf_config,
pretrained_model_name_or_path,
- dtype=dtype)
+ dtype=dtype,
+ tp=tp)
@classmethod
def from_hf_config(cls,
hf_config: Any,
model_path: str = None,
- dtype: str = 'auto'):
+ dtype: str = 'auto',
+ tp: int = 1):
"""from huggingface config."""
from lmdeploy.pytorch.configurations import AutoModelConfigBuilder
- model_config = AutoModelConfigBuilder.build(hf_config, model_path)
+ model_config = AutoModelConfigBuilder.build(hf_config,
+ model_path,
+ tp=tp)
if model_config.k_head_dim is None:
assert model_config.head_dim is not None
@@ -169,6 +178,13 @@ def from_hf_config(cls,
assert model_config.head_dim is not None
model_config.v_head_dim = model_config.head_dim
+ # check for tp
+ assert model_config.num_attention_heads % tp == 0
+ if model_config.num_key_value_heads >= tp:
+ assert model_config.num_key_value_heads % tp == 0
+ else:
+ assert tp % model_config.num_key_value_heads == 0
+
# should after setting `hf_config` and `model_arch` attributes
model_config = _update_torch_dtype(model_config, dtype)
diff --git a/lmdeploy/pytorch/configurations/builder.py b/lmdeploy/pytorch/configurations/builder.py
index 89bf51ca46..bafa78ba02 100644
--- a/lmdeploy/pytorch/configurations/builder.py
+++ b/lmdeploy/pytorch/configurations/builder.py
@@ -27,7 +27,7 @@ def condition(cls, hf_config):
f'`condition` of {cls.__name__} not implemented.')
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from .default import DefaultModelConfigBuilder
@@ -46,8 +46,21 @@ def build(cls, hf_config, model_path: str = None):
logger.debug(f'build model config with {valid_builder.__name__}')
- cfg = valid_builder.build(hf_config, model_path)
+ cfg = valid_builder.build(hf_config, model_path, **kwargs)
if cfg.hf_config is None:
cfg.hf_config = hf_config
return cfg
+
+ @classmethod
+ def update_num_kv_heads(cls, hf_config, tp, num_key_value_heads):
+ """update num kv heads."""
+ # update num_kv_heads for tp mode
+ if tp > 1 and tp > num_key_value_heads:
+ assert tp % num_key_value_heads == 0
+ n_replicate = tp // num_key_value_heads
+ hf_config.num_replicate_key_value_heads = n_replicate
+ num_key_value_heads = tp
+
+ hf_config.num_key_value_heads = num_key_value_heads
+ return num_key_value_heads
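A standalone sketch of the kv-head replication added here: when the tensor-parallel world size exceeds the number of kv heads, each head is replicated so every rank still owns at least one. The hf_config below is a plain namespace standing in for a transformers config.

from types import SimpleNamespace


def update_num_kv_heads(hf_config, tp, num_key_value_heads):
    if tp > 1 and tp > num_key_value_heads:
        assert tp % num_key_value_heads == 0
        n_replicate = tp // num_key_value_heads
        hf_config.num_replicate_key_value_heads = n_replicate
        num_key_value_heads = tp
    hf_config.num_key_value_heads = num_key_value_heads
    return num_key_value_heads


cfg = SimpleNamespace()
# 2 kv heads sharded over tp=8: each head is replicated 4x, 8 kv heads in total
assert update_num_kv_heads(cfg, tp=8, num_key_value_heads=2) == 8
assert cfg.num_replicate_key_value_heads == 4
# tp <= kv heads: nothing to replicate
assert update_num_kv_heads(SimpleNamespace(), tp=4, num_key_value_heads=8) == 8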
diff --git a/lmdeploy/pytorch/configurations/chatglm.py b/lmdeploy/pytorch/configurations/chatglm.py
index 7911c985d5..fbf4d48281 100644
--- a/lmdeploy/pytorch/configurations/chatglm.py
+++ b/lmdeploy/pytorch/configurations/chatglm.py
@@ -12,16 +12,27 @@ def condition(cls, hf_config):
return hf_config.model_type == 'chatglm'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = hf_config.hidden_size // hf_config.num_attention_heads
bos_token_id = hf_config.bos_token_id
if bos_token_id is None:
bos_token_id = hf_config.pad_token_id
+
+ if hf_config.multi_query_attention:
+ num_key_value_heads = hf_config.multi_query_group_num
+ else:
+ num_key_value_heads = hf_config.num_attention_heads
+
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
+ num_key_value_heads)
+
cfg = ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_layers,
num_attention_heads=hf_config.num_attention_heads,
- num_key_value_heads=hf_config.multi_query_group_num,
+ num_key_value_heads=num_key_value_heads,
bos_token_id=bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
diff --git a/lmdeploy/pytorch/configurations/cogvlm.py b/lmdeploy/pytorch/configurations/cogvlm.py
index b24d92d794..4736dfee69 100644
--- a/lmdeploy/pytorch/configurations/cogvlm.py
+++ b/lmdeploy/pytorch/configurations/cogvlm.py
@@ -12,12 +12,15 @@ def condition(cls, hf_config):
return model_arch == 'CogVLMForCausalLM'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from lmdeploy.utils import is_bf16_supported
- cfg = DefaultModelConfigBuilder.build(hf_config)
if getattr(hf_config, 'num_multi_query_heads', None):
- cfg.num_key_value_heads = hf_config.num_multi_query_heads
+ hf_config.num_key_value_heads = hf_config.num_multi_query_heads
+ else:
+ hf_config.num_key_value_heads = hf_config.num_attention_heads
+
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
cfg.cogvlm_style = True
torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'
hf_config.torch_dtype = torch_dtype
diff --git a/lmdeploy/pytorch/configurations/dbrx.py b/lmdeploy/pytorch/configurations/dbrx.py
index 2c8128a5a6..dcc1222b0d 100644
--- a/lmdeploy/pytorch/configurations/dbrx.py
+++ b/lmdeploy/pytorch/configurations/dbrx.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return hf_config.model_type == 'dbrx'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
hidden_size = hf_config.d_model
num_heads = hf_config.n_heads
diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py
index 37aa4b0d69..bf06ff0c33 100644
--- a/lmdeploy/pytorch/configurations/deepseek_v2.py
+++ b/lmdeploy/pytorch/configurations/deepseek_v2.py
@@ -9,16 +9,22 @@ class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):
@classmethod
def condition(cls, hf_config):
"""config."""
- return hf_config.model_type == 'deepseek_v2'
+ return hf_config.model_type in ['deepseek_v3', 'deepseek_v2']
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim)
k_head_dim = head_dim
v_head_dim = 0
num_attention_heads = hf_config.num_attention_heads
+ # multi query attn
num_key_value_heads = 1
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
+ num_key_value_heads)
+
return ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
num_attention_heads=num_attention_heads,
@@ -28,5 +34,4 @@ def build(cls, hf_config, model_path: str = None):
head_dim=head_dim,
k_head_dim=k_head_dim,
v_head_dim=v_head_dim,
- vocab_size=hf_config.vocab_size,
- multi_query_attention=True)
+ vocab_size=hf_config.vocab_size)
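Illustrative numbers for the DeepSeek-V2/V3 MLA cache layout produced by this builder (not taken from a real checkpoint): the compressed kv vector plus the rope part live entirely in the k cache, the v cache is empty, and the single latent kv head is replicated per tensor-parallel rank by update_num_kv_heads.

kv_lora_rank, qk_rope_head_dim, tp = 512, 64, 4

head_dim = kv_lora_rank + qk_rope_head_dim   # 576
k_head_dim, v_head_dim = head_dim, 0         # everything is cached on the k side
num_key_value_heads = 1                      # MLA keeps one latent kv head
if tp > num_key_value_heads:
    num_key_value_heads = tp                 # one replicated copy per rank

assert (head_dim, k_head_dim, v_head_dim, num_key_value_heads) == (576, 576, 0, 4)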
diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py
index 1f84b810ea..d1337a241e 100644
--- a/lmdeploy/pytorch/configurations/default.py
+++ b/lmdeploy/pytorch/configurations/default.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return True
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = hf_config.hidden_size // hf_config.num_attention_heads
num_attention_heads = hf_config.num_attention_heads
@@ -23,6 +23,11 @@ def build(cls, hf_config, model_path: str = None):
if use_sliding_window:
sliding_window = getattr(hf_config, 'sliding_window',
sliding_window) or -1
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
+ num_key_value_heads)
+
return ModelConfig(
hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
diff --git a/lmdeploy/pytorch/configurations/falcon.py b/lmdeploy/pytorch/configurations/falcon.py
index db4d00e397..a4c8d4d44f 100644
--- a/lmdeploy/pytorch/configurations/falcon.py
+++ b/lmdeploy/pytorch/configurations/falcon.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return hf_config.model_type == 'falcon'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build falcon."""
num_attention_heads = hf_config.num_attention_heads
if hf_config.new_decoder_architecture:
@@ -24,6 +24,12 @@ def build(cls, hf_config, model_path: str = None):
else:
# rw-1b, MHA
kv_head = num_attention_heads
+
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ kv_head = cls.update_num_kv_heads(hf_config, tp, kv_head)
+ hf_config.num_kv_heads = kv_head
+
head_dim = hf_config.hidden_size // num_attention_heads
return ModelConfig(
hidden_size=hf_config.hidden_size,
@@ -33,6 +39,5 @@ def build(cls, hf_config, model_path: str = None):
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
- multi_query_attention=hf_config.multi_query,
vocab_size=hf_config.vocab_size,
)
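A standalone sketch of the falcon kv-head selection above, covering the three attention variants (GQA with new_decoder_architecture, MQA with multi_query, plain MHA) plus the shared tensor-parallel replication rule. The config object is a stand-in with illustrative values.

from types import SimpleNamespace


def falcon_kv_heads(hf_config, tp=1):
    if hf_config.new_decoder_architecture:       # falcon-40b style GQA
        kv_head = hf_config.num_kv_heads
    elif hf_config.multi_query:                  # falcon-7b style MQA
        kv_head = 1
    else:                                        # rw-1b style MHA
        kv_head = hf_config.num_attention_heads
    if tp > 1 and tp > kv_head:                  # same rule as update_num_kv_heads
        assert tp % kv_head == 0
        kv_head = tp
    return kv_head


mqa_cfg = SimpleNamespace(new_decoder_architecture=False, multi_query=True,
                          num_attention_heads=64, num_kv_heads=None)
assert falcon_kv_heads(mqa_cfg, tp=1) == 1
assert falcon_kv_heads(mqa_cfg, tp=4) == 4       # the single kv head is replicated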
diff --git a/lmdeploy/pytorch/configurations/gemma.py b/lmdeploy/pytorch/configurations/gemma.py
index 338eaee6d0..d49fdbd96c 100644
--- a/lmdeploy/pytorch/configurations/gemma.py
+++ b/lmdeploy/pytorch/configurations/gemma.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from lmdeploy.pytorch.config import ModelConfig
-
from .builder import AutoModelConfigBuilder
+from .default import DefaultModelConfigBuilder
class GemmaModelConfigBuilder(AutoModelConfigBuilder):
@@ -12,13 +11,8 @@ def condition(cls, hf_config):
return hf_config.model_type in ['gemma', 'gemma2']
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build gemma."""
- return ModelConfig(hidden_size=hf_config.hidden_size,
- num_layers=hf_config.num_hidden_layers,
- num_attention_heads=hf_config.num_attention_heads,
- num_key_value_heads=hf_config.num_key_value_heads,
- bos_token_id=hf_config.bos_token_id,
- eos_token_id=hf_config.eos_token_id,
- head_dim=hf_config.head_dim,
- vocab_size=hf_config.vocab_size)
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
+ cfg.head_dim = hf_config.head_dim
+ return cfg
diff --git a/lmdeploy/pytorch/configurations/internvl.py b/lmdeploy/pytorch/configurations/internvl.py
index 76b4187c5f..ffff0a0e15 100644
--- a/lmdeploy/pytorch/configurations/internvl.py
+++ b/lmdeploy/pytorch/configurations/internvl.py
@@ -11,8 +11,9 @@ def condition(cls, hf_config):
return hf_config.architectures[0] == 'InternVLChatModel'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
- cfg = DefaultModelConfigBuilder.build(hf_config.llm_config)
+ cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path,
+ **kwargs)
cfg.hf_config = hf_config
return cfg
diff --git a/lmdeploy/pytorch/configurations/llava.py b/lmdeploy/pytorch/configurations/llava.py
deleted file mode 100644
index aaeeeeadfe..0000000000
--- a/lmdeploy/pytorch/configurations/llava.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .builder import AutoModelConfigBuilder
-from .default import DefaultModelConfigBuilder
-
-
-class LlavaModelConfigBuilder(AutoModelConfigBuilder):
-
- @classmethod
- def condition(cls, hf_config):
- """config."""
- return hf_config.architectures[0] in [
- 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM'
- ]
-
- @classmethod
- def build(cls, hf_config, model_path: str = None):
- """build."""
- arch = hf_config.architectures[0]
- if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']:
- from llava.model.language_model.llava_llama import LlavaConfig
-
- # reload hf_config due to model_type='llava' is already
- # registered in transformers
- hf_config = LlavaConfig.from_pretrained(model_path)
- cfg = DefaultModelConfigBuilder.build(hf_config)
- return cfg
diff --git a/lmdeploy/pytorch/configurations/llava_hf.py b/lmdeploy/pytorch/configurations/llava_hf.py
index 4cc007e313..5334eaec25 100644
--- a/lmdeploy/pytorch/configurations/llava_hf.py
+++ b/lmdeploy/pytorch/configurations/llava_hf.py
@@ -15,7 +15,7 @@ def condition(cls, hf_config):
]
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
text_config = hf_config.text_config
hidden_size = getattr(text_config, 'hidden_size', 4096)
diff --git a/lmdeploy/pytorch/configurations/medusa.py b/lmdeploy/pytorch/configurations/medusa.py
index 4935bc0e25..a4f705cd3f 100644
--- a/lmdeploy/pytorch/configurations/medusa.py
+++ b/lmdeploy/pytorch/configurations/medusa.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return hf_config.architectures[0] == 'MedusaModel'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from transformers import AutoConfig
base_config = AutoConfig.from_pretrained(
diff --git a/lmdeploy/pytorch/configurations/minicpm3.py b/lmdeploy/pytorch/configurations/minicpm3.py
index 7cde51bd42..857673aab3 100644
--- a/lmdeploy/pytorch/configurations/minicpm3.py
+++ b/lmdeploy/pytorch/configurations/minicpm3.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from lmdeploy.pytorch.config import ModelConfig
from .builder import AutoModelConfigBuilder
+from .default import DefaultModelConfigBuilder
class MiniCPM3ModelConfigBuilder(AutoModelConfigBuilder):
@@ -12,21 +12,13 @@ def condition(cls, hf_config):
return hf_config.architectures[0] in ['MiniCPM3ForCausalLM']
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim)
- k_head_dim = head_dim
- v_head_dim = head_dim
- num_attention_heads = hf_config.num_attention_heads
- num_key_value_heads = hf_config.num_key_value_heads
- return ModelConfig(hidden_size=hf_config.hidden_size,
- num_layers=hf_config.num_hidden_layers,
- num_attention_heads=num_attention_heads,
- num_key_value_heads=num_key_value_heads,
- bos_token_id=hf_config.bos_token_id,
- eos_token_id=hf_config.eos_token_id,
- head_dim=head_dim,
- k_head_dim=k_head_dim,
- v_head_dim=v_head_dim,
- vocab_size=hf_config.vocab_size,
- multi_query_attention=False)
+
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
+ cfg.head_dim = head_dim
+ cfg.k_head_dim = head_dim
+ cfg.v_head_dim = head_dim
+
+ return cfg
diff --git a/lmdeploy/pytorch/configurations/mllama.py b/lmdeploy/pytorch/configurations/mllama.py
index 2383c92c50..e56e0fbed4 100644
--- a/lmdeploy/pytorch/configurations/mllama.py
+++ b/lmdeploy/pytorch/configurations/mllama.py
@@ -11,8 +11,9 @@ def condition(cls, hf_config):
return hf_config.architectures[0] == 'MllamaForConditionalGeneration'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
- cfg = DefaultModelConfigBuilder.build(hf_config.text_config)
+ cfg = DefaultModelConfigBuilder.build(hf_config.text_config,
+ model_path, **kwargs)
cfg.hf_config = hf_config
return cfg
diff --git a/lmdeploy/pytorch/configurations/qwen.py b/lmdeploy/pytorch/configurations/qwen.py
index 05ac77c1d1..eda726de43 100644
--- a/lmdeploy/pytorch/configurations/qwen.py
+++ b/lmdeploy/pytorch/configurations/qwen.py
@@ -11,10 +11,10 @@ def condition(cls, hf_config):
return hf_config.model_type == 'qwen'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from lmdeploy.utils import is_bf16_supported
- cfg = DefaultModelConfigBuilder.build(hf_config)
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
if cfg.bos_token_id is None:
cfg.bos_token_id = 151644
if cfg.eos_token_id is None:
diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py
index e393adeed3..e3f97cfe46 100644
--- a/lmdeploy/pytorch/engine/cache_engine.py
+++ b/lmdeploy/pytorch/engine/cache_engine.py
@@ -44,7 +44,13 @@ def __init__(
self.num_layers = model_config.num_layers
self.kv_cache_dtype = model_config.dtype
if cache_config.quant_policy > 0:
- self.kv_cache_dtype = torch.uint8
+ if self.cache_config.device_type in ['cuda']:
+ self.kv_cache_dtype = torch.uint8
+ elif self.cache_config.device_type in ['ascend', 'npu']:
+ self.kv_cache_dtype = torch.int8
+ else:
+ raise ValueError(
+ f'unsupported device_type {self.cache_config.device_type}')
# Initialize the cache.
self.local_gpu_cache = self.allocate_gpu_cache()
@@ -92,7 +98,7 @@ def _get_key_block_shape_impl(cls,
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
- if local and not model_config.multi_query_attention:
+ if local:
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
@@ -115,7 +121,7 @@ def _get_value_block_shape_impl(cls,
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
- if local and not model_config.multi_query_attention:
+ if local:
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
@@ -202,7 +208,7 @@ def allocate_gpu_cache(self):
def allocate_cpu_cache(self):
"""allocate caches on Host."""
- caches = self._allocate_cache(self.num_gpu_blocks, 'cpu')
+ caches = self._allocate_cache(self.num_cpu_blocks, 'cpu')
self.full_cpu_cache = caches
self.local_cpu_cache = list(zip(*caches))
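A standalone sketch of the two cache-engine fixes above: quantized kv caches become torch.uint8 on cuda but torch.int8 on ascend/npu, and the host cache is now sized with num_cpu_blocks instead of num_gpu_blocks. Only the dtype rule is reproduced here.

import torch


def kv_cache_dtype(model_dtype, quant_policy, device_type):
    """Pick the kv-cache dtype the same way CacheEngine.__init__ does above."""
    if quant_policy == 0:
        return model_dtype                 # unquantized: keep the model dtype
    if device_type in ['cuda']:
        return torch.uint8
    if device_type in ['ascend', 'npu']:
        return torch.int8
    raise ValueError(f'unsupported device_type {device_type}')


assert kv_cache_dtype(torch.float16, quant_policy=0, device_type='cuda') == torch.float16
assert kv_cache_dtype(torch.float16, quant_policy=8, device_type='cuda') == torch.uint8
assert kv_cache_dtype(torch.bfloat16, quant_policy=4, device_type='ascend') == torch.int8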
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 263ef784ee..19c8017b7a 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -8,19 +8,17 @@
import numpy as np
import torch
-from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
- ResponseType)
+from lmdeploy.messages import PytorchEngineConfig, ResponseType
from lmdeploy.utils import (get_logger, get_max_batch_size, get_model,
logging_timer)
from ..adapter.adapter import AdapterManager
-from ..check_env import check_adapters, check_env, check_model
from ..config import BackendConfig, CacheConfig, SchedulerConfig
from ..devices import DeviceContext, get_device_manager
-from ..messages import (InputEmbeddingRangeType, InputEmbeddingType,
- MessageStatus, SchedulerSequence)
-from ..model_inputs import ModelInputs, MRopeModelInputs, VisionModelInputs
+from ..messages import MessageStatus, SchedulerSequence
+from ..model_inputs import ModelInputs, VisionModelInputs
from ..paging import Scheduler
+from .engine_checker import EngineChecker
from .logits_process import FusedLogitsProcessor, SamplingInputs
from .model_agent import build_model_agent
from .request import Request, RequestManager, RequestType, Response
@@ -32,24 +30,13 @@
_EMPTY_TOKEN = np.empty((0, ), dtype=np.int64)
-def _raise_exception_on_finish(task: asyncio.Task) -> None:
- """raise exception on finish."""
- try:
- task.result()
- except asyncio.CancelledError:
- return
- except Exception as e:
- raise e
-
-
@dataclass
class InferOutput:
"""The output of the model inference."""
session_id: int
+ resp: Response
token_ids: List[int]
- sender_id: int
- req_id: int
meta: Any = None
finish: bool = False
logits: torch.Tensor = None
@@ -78,6 +65,40 @@ def _check_finish(scheduler: Scheduler, current_iter: int):
return False
+def _build_scheduler_config(engine_config: PytorchEngineConfig):
+ """build scheduler config."""
+ scheduler_config = SchedulerConfig(
+ max_batches=engine_config.max_batch_size,
+ max_session_len=engine_config.session_len,
+ prefill_interval=engine_config.prefill_interval)
+ return scheduler_config
+
+
+def _build_cache_config(engine_config: PytorchEngineConfig):
+ """build cache config."""
+ cache_config = CacheConfig(
+ max_batches=engine_config.max_batch_size,
+ block_size=engine_config.block_size,
+ num_cpu_blocks=engine_config.num_cpu_blocks,
+ num_gpu_blocks=engine_config.num_gpu_blocks,
+ cache_max_entry_count=engine_config.cache_max_entry_count,
+ max_prefill_token_num=engine_config.max_prefill_token_num,
+ enable_prefix_caching=engine_config.enable_prefix_caching,
+ quant_policy=engine_config.quant_policy,
+ device_type=engine_config.device_type,
+ )
+ return cache_config
+
+
+def _build_backend_config(engine_config: PytorchEngineConfig):
+ """build backend config."""
+ backend_config = BackendConfig(
+ eager_mode=engine_config.eager_mode,
+ device_type=engine_config.device_type,
+ )
+ return backend_config
+
+
class Engine:
"""The inference engine of lmdeploy pytorch.
@@ -97,43 +118,23 @@ def __init__(self,
engine_config = PytorchEngineConfig()
else:
engine_config = copy.deepcopy(engine_config)
- check_env(engine_config.device_type)
- check_model(model_path, trust_remote_code, engine_config.dtype,
- engine_config.device_type)
if engine_config.max_batch_size is None:
engine_config.max_batch_size = get_max_batch_size(
engine_config.device_type)
- adapters = engine_config.adapters
- if adapters is not None:
- check_adapters(list(adapters.values()))
- assert engine_config.max_batch_size > 0, 'max_batch_size should be' \
- f' greater than 0, but got {engine_config.max_batch_size}'
- assert engine_config.dtype in ['auto', 'float16', 'bfloat16'], \
- f'unsupported specified data type {engine_config.dtype}'
+ checker = EngineChecker(model_path=model_path,
+ engine_config=engine_config,
+ trust_remote_code=trust_remote_code,
+ logger=logger)
+ checker.handle()
+
+ adapters = engine_config.adapters
self.engine_config = engine_config
self.tp = engine_config.tp
self.device_context = DeviceContext(
device_type=engine_config.device_type)
- scheduler_config = SchedulerConfig(
- max_batches=engine_config.max_batch_size,
- max_session_len=engine_config.session_len,
- prefill_interval=engine_config.prefill_interval)
-
- # block_size = 1 to enable unified paging
- cache_config = CacheConfig(
- max_batches=engine_config.max_batch_size,
- block_size=engine_config.block_size,
- num_cpu_blocks=engine_config.num_cpu_blocks,
- num_gpu_blocks=engine_config.num_gpu_blocks,
- cache_max_entry_count=engine_config.cache_max_entry_count,
- max_prefill_token_num=engine_config.max_prefill_token_num,
- enable_prefix_caching=engine_config.enable_prefix_caching,
- quant_policy=engine_config.quant_policy,
- )
-
if not os.path.exists(model_path):
model_path = get_model(model_path, engine_config.download_dir,
engine_config.revision)
@@ -142,10 +143,9 @@ def __init__(self,
if adapters is not None and len(adapters) > 0:
adapters = self._download_adapters(adapters, engine_config)
- backend_config = BackendConfig(
- eager_mode=engine_config.eager_mode,
- device_type=engine_config.device_type,
- )
+ scheduler_config = _build_scheduler_config(engine_config)
+ cache_config = _build_cache_config(engine_config)
+ backend_config = _build_backend_config(engine_config)
with get_device_manager().context(self.device_context):
self.model_agent = build_model_agent(
@@ -159,6 +159,8 @@ def __init__(self,
dtype=engine_config.dtype,
custom_module_map=engine_config.custom_module_map)
+ self.input_processor = self.model_agent.get_input_processor()
+
cache_config = self.model_agent.cache_config
self.adapter_manager = self._build_adapter_manager(adapters)
self.scheduler = Scheduler(scheduler_config, cache_config)
@@ -174,7 +176,6 @@ def __init__(self,
# create main thread
self._start_loop()
self._create_buffers()
- self.engine_instance = self.create_instance()
self._output_stream = torch.cuda.Stream()
@classmethod
@@ -244,7 +245,7 @@ def _build_adapter_manager(self, adapters):
def _bind_request_manager(self):
"""bind request manager."""
- req_manager = RequestManager(self.engine_config.thread_safe)
+ req_manager = RequestManager()
req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session)
req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session)
req_manager.bind_func(RequestType.END_SESSION, self._on_end_session)
@@ -256,18 +257,15 @@ def _start_loop(self):
return self.req_manager.start_loop(self.async_loop)
def _response(self,
+ resp: Response,
resp_type: ResponseType,
- sender_id: int,
- req_id: int,
data: Any = None,
err_msg: str = ''):
"""response."""
- self.req_manager.response(
- Response(type=resp_type,
- sender_id=sender_id,
- req_id=req_id,
- data=data,
- err_msg=err_msg))
+ resp.type = resp_type
+ resp.data = data
+ resp.err_msg = err_msg
+ self.req_manager.response(resp)
def _get_max_session_len(self):
"""get max session len."""
@@ -293,7 +291,7 @@ def _on_add_session(self, reqs: Request, **kwargs):
self.scheduler.add_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
- self._response(resp_type, req.sender_id, req.req_id)
+ self._response(req.resp, resp_type)
def _on_stop_session(self, reqs: Request, **kwargs):
"""on stop session callback."""
@@ -305,7 +303,7 @@ def _on_stop_session(self, reqs: Request, **kwargs):
self.scheduler.stop_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
- self._response(resp_type, req.sender_id, req.req_id)
+ self._response(req.resp, resp_type)
def _on_end_session(self, reqs: Request, **kwargs):
"""on end session callback."""
@@ -317,10 +315,35 @@ def _on_end_session(self, reqs: Request, **kwargs):
self.scheduler.end_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
- self._response(resp_type, req.sender_id, req.req_id)
+ self._response(req.resp, resp_type)
def _on_add_message(self, reqs: Request, **kwargs):
"""on add message callback."""
+ for req in reqs:
+ req_data = req.data
+ if req_data.get('input_multimodals', None) is None:
+ continue
+ elif self.input_processor is None:
+                logger.warning('Multimodal inputs are not supported.')
+ continue
+ input_ids = req_data['token_ids']
+ input_multimodals = req_data['input_multimodals']
+ if len(input_multimodals) == 0:
+ req_data['input_multimodals'] = None
+ continue
+ result = self.input_processor.preprocess_input(
+ input_ids, input_multimodals)
+
+ input_ids = result.input_ids
+ input_multimodals = result.input_multimodals
+
+ req_data['token_ids'] = input_ids
+ req_data['input_multimodals'] = input_multimodals
+
+ if len(reqs) > 0:
+ self._add_message(reqs)
+
+ def _add_message(self, reqs):
def __update_bad_words(msg):
"""update bad words."""
@@ -346,8 +369,7 @@ def __update_max_new_tokens(msg):
for req in reqs:
session_id = req.data['session_id']
if session_id not in self.scheduler.sessions:
- self._response(ResponseType.SESSION_NOT_EXIST, req.sender_id,
- req.req_id)
+ self._response(req.resp, ResponseType.SESSION_NOT_EXIST)
continue
session_id = req.data['session_id']
sess = self.scheduler.sessions[session_id]
@@ -360,11 +382,8 @@ def __update_max_new_tokens(msg):
sampling_param=req.data['sampling_param'],
adapter_name=req.data['adapter_name'],
return_logits=req.data.get('return_logits', False),
+ multimodals=req.data.get('input_multimodals'),
input_embeddings=req.data.get('input_embeddings'),
- mrope_position_ids=req.data.get('mrope_position_ids'),
- mrope_position_delta=req.data.get('mrope_position_delta'),
- cross_attention_states=req.data.get(
- 'cross_attention_states'),
)
msg = next(iter(sess.sequences.values()))
__update_bad_words(msg)
@@ -372,9 +391,11 @@ def __update_max_new_tokens(msg):
self.scheduler.add_sequence(msg)
else:
msg = next(iter(sess.sequences.values()))
- msg.update_token_ids(req.data['token_ids'],
- req.data.get('input_embeddings'),
- req.data.get('cross_attention_states'))
+ msg.update_token_ids(
+ req.data['token_ids'],
+ multimodals=req.data.get('input_multimodals'),
+ embeddings=req.data.get('input_embeddings'),
+ )
msg.num_new_tokens = 0
msg.sampling_param = req.data['sampling_param']
msg.return_logits = req.data.get('return_logits', False)
@@ -382,8 +403,7 @@ def __update_max_new_tokens(msg):
__update_bad_words(msg)
__update_max_new_tokens(msg)
- msg.sender_id = req.sender_id
- msg.req_id = req.req_id
+ msg.resp = req.resp
@property
def model_config(self):
@@ -420,7 +440,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
seq_length = self._seq_length_buf[:batch_size]
max_q_seq_length = seq_length.max().item()
- # TODO: get block offsets is slow when block_size = 1
block_offsets = self.scheduler.get_block_tables(messages)
block_offsets = _tensorlize_block_offsets(block_offsets)
@@ -438,13 +457,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
num_ignored_history = [msg.num_ignored_history for msg in messages]
num_ignored_history = torch.tensor(num_ignored_history)
- def __get_cogvlm_image_info():
- """Get cogvlm history image info for position ids."""
- history_image_nums = torch.LongTensor(
- [msg.history_image_num for msg in messages])
- history_image_token_lengths = torch.LongTensor(
- [msg.history_image_token_len for msg in messages])
- return history_image_nums, history_image_token_lengths
+ model_metas = [msg.model_meta for msg in messages]
def __get_vlm_embeddings():
"""get vlm input embeddings and indexings."""
@@ -469,25 +482,9 @@ def __get_vlm_embeddings():
return (input_embeddings, input_embedding_indexing,
input_embedding_ranges)
- def __get_mrope_inputs():
- """get multimodal rotary position inputs."""
- position_ids = [msg.mrope_position_ids for msg in messages]
- deltas = [msg.mrope_position_delta for msg in messages]
- return MRopeModelInputs(position_ids=position_ids, deltas=deltas)
-
# for inputs with embeddings
history_image_nums = None
history_image_token_lengths = None
- # only for cogvlm
- if self.model_config.cogvlm_style:
- (history_image_nums,
- history_image_token_lengths) = __get_cogvlm_image_info()
- # only for qwen2_vl
- mrope_inputs = None
- has_mrope_params = any(
- [msg.mrope_position_ids is not None for msg in messages])
- if has_mrope_params:
- mrope_inputs = __get_mrope_inputs()
input_embeddings = None
input_embedding_indexing = None
@@ -498,25 +495,40 @@ def __get_mrope_inputs():
(input_embeddings, input_embedding_indexing,
input_embedding_ranges) = __get_vlm_embeddings()
+ input_multimodals = None
+ has_multimodal = any(
+ [not msg.history_multimodals.empty() for msg in messages])
+ if has_multimodal:
+ has_multimodal = False
+ input_multimodals = [
+ msg.get_input_multimodals() for msg in messages
+ ]
+ for input_mm in input_multimodals:
+ for val in input_mm.values():
+ if len(val) > 0:
+ has_multimodal = True
+ break
+ if has_multimodal:
+ break
+
vision_embedding_inputs = None
- if has_embedding or history_image_nums is not None:
+ if has_embedding or has_multimodal or history_image_nums is not None:
vision_embedding_inputs = VisionModelInputs(
history_lengths=history_lengths,
history_image_nums=history_image_nums,
history_image_token_lengths=history_image_token_lengths,
input_embeddings=input_embeddings,
input_embedding_indexing=input_embedding_indexing,
- input_embedding_ranges=input_embedding_ranges)
-
- # only for mllama
- cross_attention_states = None
- history_cross_kv_seqlens = None
- if any([msg.cross_attention_states is not None for msg in messages]):
- cross_attention_states = [
- msg.cross_attention_states for msg in messages
- ]
- history_cross_kv_seqlens = torch.tensor(
- [msg.history_cross_kv_seqlens for msg in messages])
+ input_embedding_ranges=input_embedding_ranges,
+ input_multimodals=input_multimodals)
+
+ # cross
+ cross_length = torch.tensor([msg.num_cross for msg in messages])
+ history_cross_length = torch.tensor(
+ [msg.num_history_cross for msg in messages])
+ if (cross_length + history_cross_length).max().item() == 0:
+ cross_length = None
+ history_cross_length = None
return ModelInputs(
input_ids=input_ids,
@@ -527,9 +539,9 @@ def __get_mrope_inputs():
num_ignored_history=num_ignored_history,
local_adapter_ids=local_adapter_ids,
vision_inputs=vision_embedding_inputs,
- mrope_inputs=mrope_inputs,
- cross_attention_states=cross_attention_states,
- history_cross_kv_seqlens=history_cross_kv_seqlens,
+ cross_length=cross_length,
+ history_cross_length=history_cross_length,
+ model_metas=model_metas,
)
def _batch_stopping_criteria(self, token_ids: torch.Tensor,
@@ -550,11 +562,12 @@ def _batch_stopping_criteria(self, token_ids: torch.Tensor,
return stopped, num_appendable_ids
@logging_timer('SamplingLogits', logger)
- def async_sampling_logits(self, logits: torch.Tensor,
- all_ids: torch.Tensor,
- guided_input_ids: torch.Tensor,
- sampling_inputs: SamplingInputs,
- inputs: ModelInputs, ignore_eos: torch.Tensor):
+ async def async_sampling_logits(self, logits: torch.Tensor,
+ all_ids: torch.Tensor,
+ guided_input_ids: torch.Tensor,
+ sampling_inputs: SamplingInputs,
+ inputs: ModelInputs,
+ ignore_eos: torch.Tensor):
"""sampling logits."""
def __get_last_logits():
@@ -569,7 +582,8 @@ def __get_last_logits():
split_logits = __get_last_logits()
logits_processor = FusedLogitsProcessor(sampling_inputs, ignore_eos,
self.tokenizer.model.model)
- logits = logits_processor(all_ids, guided_input_ids, split_logits)
+ logits = await logits_processor(all_ids, guided_input_ids,
+ split_logits)
next_token_ids = logits_processor.sampling(logits)
return next_token_ids
@@ -587,20 +601,24 @@ def extract_tokens(self, token_ids, eos_token_ids):
@logging_timer('UpdateRunning', logger)
def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
- stopped: torch.Tensor):
+ stopped: torch.Tensor, model_metas: List[Dict[str,
+ Any]]):
"""update scheduler."""
+ if model_metas is None:
+ model_metas = [None] * len(running)
next_token_ids = next_token_ids.numpy()
- eos_token_id = self.model_config.eos_token_id
- for token, msg, stop in zip(next_token_ids, running, stopped):
+ for token, msg, stop, model_meta in zip(next_token_ids, running,
+ stopped, model_metas):
if msg.status != MessageStatus.RUNNING:
continue
+ eos_token_id = self.model_config.eos_token_id
update_token, eos_stop = self.extract_tokens(token, eos_token_id)
stop = stop or eos_stop
if stop:
update_token = _EMPTY_TOKEN
else:
msg.num_new_tokens += len(update_token)
- msg.update_token_ids(update_token)
+ msg.update_token_ids(update_token, model_meta=model_meta)
if stop:
msg.status = MessageStatus.STOPPED
@@ -666,12 +684,14 @@ async def __long_context_single_forward(inputs):
batch_size = seq_len.size(0)
assert batch_size == 1
- new_inputs = inputs.split(max_prefill_token_num,
- self.cache_config.block_size)
+ new_inputs = inputs.split(max_prefill_token_num)
+ model_metas = new_inputs[0].model_metas
output_gather = _OutputGather(max_seq_len)
for inp in new_inputs:
+ inp.model_metas = model_metas
tmp_out = await __forward(inp)
+ model_metas = tmp_out.get('model_metas')
output_gather.gather(tmp_out)
tmp_out.pop('hidden_states', None)
tmp_out['hidden_states'] = output_gather.get_output()
@@ -703,33 +723,12 @@ async def __long_context_single_forward(inputs):
ret['logits'] = logits
return ret
- def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
- logits: torch.Tensor, stopped: torch.Tensor,
- event: torch.cuda.Event):
+ async def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
+ logits: torch.Tensor, stopped: torch.Tensor,
+ model_metas: List[Dict[str, Any]],
+ event: torch.cuda.Event):
"""make infer output."""
- def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence,
- stopped: bool):
- """check if output is necessary."""
- if isinstance(token, list):
- idx = len(token)
- for i, t in enumerate(token):
- if t == -1:
- idx = i
- break
- if stopped:
- idx = min(
- idx,
- msg.sampling_param.max_new_tokens - msg.num_new_tokens)
- token = token[:idx]
- else:
- if stopped:
- return []
- if token in msg.sampling_param.stop_words:
- return []
- token = [token]
- return token
-
def __get_q_start_loc():
inputs = self._inputs
seq_length = inputs.seq_length
@@ -739,15 +738,16 @@ def __get_q_start_loc():
else:
return seq_length.cumsum(0) - seq_length
+ while not event.query():
+ await asyncio.sleep(0.001)
with torch.cuda.stream(self._output_stream):
- event.wait()
next_token_ids = next_token_ids.cpu()
stopped = stopped.cpu()
running = self._running
is_run = [seq.status == MessageStatus.RUNNING for seq in running]
stopped = stopped.tolist()
- self.update_running(running, next_token_ids, stopped)
+ self.update_running(running, next_token_ids, stopped, model_metas)
# generate output
next_token_ids = next_token_ids.tolist()
@@ -756,16 +756,15 @@ def __get_q_start_loc():
for idx, msg in enumerate(running):
if not is_run[idx]:
continue
- token_ids = __get_out_token_ids(next_token_ids[idx], msg,
- stopped[idx])
+ token_ids = msg.all_ids[-msg.num_new_tokens:]
finish = msg.status == MessageStatus.STOPPED
if not finish and len(token_ids) == 0:
continue
session_id = msg.session_id
+ resp = msg.resp
out = InferOutput(
session_id=session_id,
- sender_id=msg.sender_id,
- req_id=msg.req_id,
+ resp=resp,
finish=finish,
token_ids=token_ids,
)
@@ -805,8 +804,7 @@ def __update_inputs(next_token_ids):
logger.debug(': '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)}')
- if self.gpu_count == 1:
- inputs = inputs.to_device('cuda')
+ inputs = inputs.to_device('cuda')
is_decoding = inputs.is_decoding
if all_ids is not None:
all_ids = all_ids.cuda()
@@ -827,7 +825,7 @@ def __update_inputs(next_token_ids):
logits = logits[0] # [bs, seq, prob] -> [seq, prob]
# sampling
- next_token_ids = self.async_sampling_logits(
+ next_token_ids = await self.async_sampling_logits(
logits, all_ids, guided_input_ids, sampling_inputs, inputs,
num_ignore_eos > 0)
num_ignore_eos = num_ignore_eos - 1
@@ -873,13 +871,16 @@ def __update_inputs(next_token_ids):
next_token_ids, sampling_inputs.stop_words, num_appendable_ids)
# send output
+ model_metas = output.get('model_metas')
finish = (idx == loop_count - 1)
finish = finish or _check_finish(self.scheduler, idx)
event = torch.cuda.Event()
event.record()
- output = (next_token_ids, logits, stopped, event)
+ output = (next_token_ids, logits, stopped, model_metas, event)
output_que.put_nowait((finish, output))
+ inputs.model_metas = model_metas
+
if finish:
break
@@ -889,9 +890,28 @@ def __update_inputs(next_token_ids):
swap_out_map = dict()
__update_inputs(next_token_ids)
+ def _set_has_runable_event(self, has_runable_event: asyncio.Event):
+ """set has runable event."""
+ if self.scheduler.has_unfinished():
+ has_runable_event.set()
+ else:
+ has_runable_event.clear()
+
+ @torch.inference_mode()
+ async def _async_loop_preprocess_message(self,
+ forward_event: asyncio.Event,
+ has_runable_event: asyncio.Event):
+ """preprocess msg."""
+ while True:
+ if self.scheduler.has_unfinished():
+ await forward_event.wait()
+ await self.req_manager.step()
+ self._set_has_runable_event(has_runable_event)
+
@torch.inference_mode()
async def _async_loop_background(self, in_que: asyncio.Queue,
- out_que: asyncio.Queue):
+ out_que: asyncio.Queue,
+ forward_event: asyncio.Event):
"""async loop background."""
def __gather_all_ids(seqs: SeqList, sampling_inputs: SamplingInputs):
@@ -952,66 +972,52 @@ def __need_logits(seqs: SeqList):
while True:
is_prefill, scheduler_output = await in_que.get()
- try:
- running = scheduler_output.running
- swap_in_map = scheduler_output.swap_in_map
- swap_out_map = scheduler_output.swap_out_map
- prefill_interval = self.scheduler_config.prefill_interval
- loop_count = 1 if is_prefill else (prefill_interval - 1)
- assert len(running) > 0
-
- # create inputs
- inputs = self.create_model_inputs(running, is_prefill)
- sampling_inputs = SamplingInputs.from_sampling_params(running)
- all_ids = __gather_all_ids(running, sampling_inputs)
- guided_input_ids = __gather_guided_input_ids(
- running, sampling_inputs)
- num_appendable_ids = __get_num_appendable_ids(running)
- num_ignore_eos = __get_num_ignore_eos(running)
- return_logits = __need_logits(running)
-
- self._running = running
- self._inputs = inputs
-
- await self._async_step_background(
- inputs=inputs,
- swap_in_map=swap_in_map,
- swap_out_map=swap_out_map,
- all_ids=all_ids,
- guided_input_ids=guided_input_ids,
- sampling_inputs=sampling_inputs,
- num_appendable_ids=num_appendable_ids,
- num_ignore_eos=num_ignore_eos,
- loop_count=loop_count,
- return_logits=return_logits,
- output_que=out_que,
- )
- except Exception as e:
- out_que.put_nowait((True, e))
- finally:
- in_que.task_done()
-
- @torch.inference_mode()
- async def _async_loop(self):
- """Main loop of the engine.
+ running = scheduler_output.running
+ swap_in_map = scheduler_output.swap_in_map
+ swap_out_map = scheduler_output.swap_out_map
+ prefill_interval = self.scheduler_config.prefill_interval
+ loop_count = 1 if is_prefill else (prefill_interval - 1)
+ assert len(running) > 0
+
+ # create inputs
+ inputs = self.create_model_inputs(running, is_prefill)
+ sampling_inputs = SamplingInputs.from_sampling_params(running)
+ all_ids = __gather_all_ids(running, sampling_inputs)
+ guided_input_ids = __gather_guided_input_ids(
+ running, sampling_inputs)
+ num_appendable_ids = __get_num_appendable_ids(running)
+ num_ignore_eos = __get_num_ignore_eos(running)
+ return_logits = __need_logits(running)
+
+ self._running = running
+ self._inputs = inputs
+
+ forward_event.clear()
+ await self._async_step_background(
+ inputs=inputs,
+ swap_in_map=swap_in_map,
+ swap_out_map=swap_out_map,
+ all_ids=all_ids,
+ guided_input_ids=guided_input_ids,
+ sampling_inputs=sampling_inputs,
+ num_appendable_ids=num_appendable_ids,
+ num_ignore_eos=num_ignore_eos,
+ loop_count=loop_count,
+ return_logits=return_logits,
+ output_que=out_que,
+ )
+ forward_event.set()
- Each engine instance would communicate with the engine by queue.
- """
- prefill_interval = self.scheduler_config.prefill_interval
- in_que = asyncio.Queue()
- out_que = asyncio.Queue()
- loop_background = asyncio.get_event_loop().create_task(
- self._async_loop_background(in_que, out_que),
- name='MainLoopBackground')
- loop_background.add_done_callback(_raise_exception_on_finish)
+ async def _async_send_responses(self, que: asyncio.Queue,
+ forward_event: asyncio.Event):
+ """send responses."""
def __send_resp(out: InferOutput):
"""send response."""
resp_type = (ResponseType.FINISH
if out.finish else ResponseType.SUCCESS)
- self._response(resp_type,
- sender_id=out.sender_id,
- req_id=out.req_id,
+ self._response(out.resp,
+ resp_type,
data=dict(token_ids=out.token_ids,
logits=out.logits))
@@ -1020,9 +1026,89 @@ def __send_resps(step_outputs: Dict[int, InferOutput]):
for out in step_outputs.values():
__send_resp(out)
+ while True:
+ resps = await que.get()
+ if self.scheduler.has_unfinished():
+ await forward_event.wait()
+ __send_resps(resps)
+
+ @staticmethod
+ def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]):
+ """add loop tasks done callback."""
+
+ def __task_callback(task: asyncio.Task) -> None:
+ """raise exception on finish."""
+ task_name = task.get_name()
+ try:
+ task.result()
+ except asyncio.CancelledError:
+ logger.debug(f'Task <{task_name}> cancelled.')
+ return
+ except Exception:
+ logger.exception(f'Task <{task_name}> failed')
+ for task in tasks:
+ if not task.cancelled():
+ task.cancel()
+
+ for task in tasks:
+ task.add_done_callback(__task_callback)
+
+ @torch.inference_mode()
+ async def _async_loop(self):
+ """Main loop of the engine.
+
+ Each engine instance would communicate with the engine by queue.
+ """
+ event_loop = asyncio.get_event_loop()
+ prefill_interval = self.scheduler_config.prefill_interval
+
+ # forward task
+ in_que = asyncio.Queue()
+ out_que = asyncio.Queue()
+ forward_event = asyncio.Event()
+ forward_event.set()
+ loop_background = event_loop.create_task(self._async_loop_background(
+ in_que, out_que, forward_event),
+ name='MainLoopBackground')
+
+ # preprocess task
+ has_runable_event = asyncio.Event()
+ loop_msg_proc = event_loop.create_task(
+ self._async_loop_preprocess_message(forward_event,
+ has_runable_event),
+ name='MainLoopPreprocessMessage')
+
+ # response task
+ resp_que = asyncio.Queue()
+ loop_send_resp = event_loop.create_task(self._async_send_responses(
+ resp_que, forward_event),
+ name='MainLoopResponse')
+
+ loop_main = asyncio.current_task()
+ loop_tasks: List[asyncio.Task] = [
+ loop_main, loop_background, loop_msg_proc, loop_send_resp
+ ]
+ self._add_loop_tasks_done_callback(loop_tasks)
+
+ def __do_prefill():
+ # decoding if no waiting
+ if not self.scheduler.has_waiting():
+ return False
+ num_running = self.scheduler.num_running()
+ num_waiting = self.scheduler.num_waiting()
+ max_batches = self.scheduler_config.max_batches
+            # prefill if there are too many waiting requests
+ if num_waiting >= 4:
+ return True
+            # prefill if there are not enough running requests
+ if num_running < max_batches * 0.5:
+ return True
+ # decoding
+ return False
+
async def __step():
"""step decoding."""
- prefill = self.scheduler.has_waiting()
+ prefill = __do_prefill()
schedule_output = self.scheduler.schedule(
is_prefill=prefill, prealloc_size=prefill_interval)
# schedule decoding if no valid prefill reqs.
@@ -1036,29 +1122,13 @@ async def __step():
in_que.put_nowait((prefill, schedule_output))
finish = False
while not finish:
- if self.req_manager.has_requests():
- self.req_manager.step()
finish, out = await out_que.get()
- try:
- if isinstance(out, Exception):
- raise out
- next_token_ids, logits, stopped, event = out
- step_outputs = self._make_infer_outputs(
- next_token_ids, logits, stopped, event)
- __send_resps(step_outputs)
- except Exception as e:
- raise e
- finally:
- out_que.task_done()
+ step_outputs = await self._make_infer_outputs(*out)
+ self._set_has_runable_event(has_runable_event)
+ resp_que.put_nowait(step_outputs)
while True:
- if self.req_manager.has_requests():
- self.req_manager.step()
-
- if not self.scheduler.has_unfinished():
- await asyncio.sleep(0.01)
- continue
-
+ await has_runable_event.wait()
await __step()
async def async_loop(self):
@@ -1077,78 +1147,3 @@ def create_instance(self, cuda_stream_id=0):
"""
from .engine_instance import EngineInstance
return EngineInstance(self)
-
- async def async_batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None):
- """Send inference request.
-
- Args:
- session_ids (List[int]): The session id.
- token_ids (List[int]): The input token ids.
- gen_config (GenerationConfig): The sampling parameters.
- adapter_names (List[str]): The name of the adapters.
- keep_cache (bool): Keep kv cache after infer.
-
- Returns:
- int: Error flags. 0 if success.
- List[int]: The streaming output tokens.
- int: The number of the output tokens.
- """
- return await self.engine_instance.async_batched_infer(
- session_ids=session_ids,
- token_ids=token_ids,
- gen_config=gen_config,
- adapter_names=adapter_names,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- keep_cache=keep_cache)
-
- def batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None):
- """batched infer."""
- return self.engine_instance.batched_infer(
- session_ids=session_ids,
- token_ids=token_ids,
- gen_config=gen_config,
- adapter_names=adapter_names,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- keep_cache=keep_cache)
-
- async def async_add_session(self, session_id: int):
- """Add new session."""
- return await self.engine_instance._async_try_add_session(session_id)
-
- def add_session(self, session_id: int):
- """Add new session."""
- return self.engine_instance._try_add_session(session_id)
-
- async def async_cancel(self, session_id: int):
- """Stop the given session."""
- return await self.engine_instance.async_cancel(session_id)
-
- def cancel(self, session_id: int):
- """Add new session."""
- return self.engine_instance.cancel(session_id)
-
- async def async_end(self, session_id: int):
- """End the given session."""
- return await self.engine_instance.async_end(session_id)
-
- def end(self, session_id: int):
- """Add new session."""
- return self.engine_instance.end(session_id)
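With the batched_infer/add_session/cancel/end wrappers removed from Engine, callers go through create_instance() and the EngineInstance API. A hypothetical caller-side sketch; the model path and prompt ids are placeholders, and the Engine constructor arguments are assumptions to verify against the installed version.

import asyncio

from lmdeploy.messages import GenerationConfig, PytorchEngineConfig
from lmdeploy.pytorch.engine.engine import Engine


async def main():
    # placeholder model path; any supported chat model works here
    engine = Engine('internlm/internlm2_5-7b-chat',
                    engine_config=PytorchEngineConfig())
    instance = engine.create_instance()
    gen_config = GenerationConfig(max_new_tokens=32)
    async for out in instance.async_stream_infer(session_id=0,
                                                 input_ids=[1, 2, 3],  # placeholder ids
                                                 gen_config=gen_config):
        # status and token_ids mirror the EngineOutput fields used in this diff
        print(out.status, out.token_ids)


asyncio.run(main())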
diff --git a/lmdeploy/pytorch/engine/engine_checker.py b/lmdeploy/pytorch/engine/engine_checker.py
new file mode 100644
index 0000000000..5b0cc9865c
--- /dev/null
+++ b/lmdeploy/pytorch/engine/engine_checker.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.messages import PytorchEngineConfig
+
+from ..check_env.adapter import AdapterChecker
+from ..check_env.base import BaseChecker
+from ..check_env.model import ModelChecker
+from ..check_env.torch import TorchChecker
+from ..check_env.transformers import TransformersChecker
+
+
+class EngineChecker(BaseChecker):
+    """check engine config and environment."""
+
+ def __init__(self,
+ model_path: str,
+ engine_config: PytorchEngineConfig,
+ trust_remote_code: bool = True,
+ logger=None):
+ super().__init__(logger)
+ logger = self.get_logger()
+
+ self.engine_config = engine_config
+
+ dtype = engine_config.dtype
+ device_type = engine_config.device_type
+
+ # pytorch
+ torch_checker = TorchChecker(logger=logger)
+
+ if device_type == 'cuda':
+ # triton
+ from ..check_env.triton import TritonChecker
+ triton_checker = TritonChecker(logger=logger)
+ triton_checker.register_required_checker(torch_checker)
+ self.register_required_checker(triton_checker)
+ else:
+ # deeplink
+ from ..check_env.deeplink import DeeplinkChecker
+ dl_checker = DeeplinkChecker(device_type, logger=logger)
+ self.register_required_checker(dl_checker)
+ self.register_required_checker(torch_checker)
+
+ # transformers
+
+ # model
+ trans_checker = TransformersChecker()
+ model_checker = ModelChecker(model_path=model_path,
+ trust_remote_code=trust_remote_code,
+ dtype=dtype,
+ device_type=device_type,
+ logger=logger)
+ model_checker.register_required_checker(torch_checker)
+ model_checker.register_required_checker(trans_checker)
+ self.register_required_checker(model_checker)
+
+ # adapters
+ adapters = engine_config.adapters
+ if adapters is not None:
+ adapter_paths = list(adapters.values())
+ for adapter in adapter_paths:
+ adapter_checker = AdapterChecker(adapter, logger=logger)
+ self.register_required_checker(adapter_checker)
+
+ def check(self):
+ """check."""
+ engine_config = self.engine_config
+
+ if engine_config.thread_safe:
+ self.log_and_exit(
+ mod_name='Engine',
+ message='thread safe mode is no longer supported.\n'
+ 'Read https://github.com/InternLM/lmdeploy/blob/main/docs/en/advance/pytorch_multithread.md for more details.', # noqa: E501
+ )
+
+ if engine_config.max_batch_size <= 0:
+ self.log_and_exit(
+ mod_name='Engine',
+ message='max_batch_size should be'
+ f' greater than 0, but got {engine_config.max_batch_size}')
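The new checker wires environment validation through composable BaseChecker subclasses: prerequisites are registered with register_required_checker and handle() runs the chain before check(). A hypothetical extra checker following the same pattern; the FlashAttnChecker name and its probe are invented for illustration, and only the BaseChecker methods visible in this diff are relied on.

from lmdeploy.pytorch.check_env.base import BaseChecker
from lmdeploy.pytorch.check_env.torch import TorchChecker


class FlashAttnChecker(BaseChecker):
    """check flash-attn is available (illustrative only)."""

    def __init__(self, logger=None):
        super().__init__(logger)
        # flash-attn is only meaningful once torch itself passes its checks
        self.register_required_checker(TorchChecker(logger=self.get_logger()))

    def check(self):
        try:
            import flash_attn  # noqa: F401
        except ImportError as e:
            self.log_and_exit(e, 'flash_attn', 'flash-attn is not installed.')


FlashAttnChecker().handle()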
diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index 455ab1ccb3..5cf1366783 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
+from typing import Any, Dict, List
from lmdeploy.messages import EngineOutput, GenerationConfig
from lmdeploy.utils import get_logger
-from ..messages import (InputEmbeddingRangeType, InputEmbeddings,
- InputEmbeddingType, SamplingParam)
+from ..messages import SamplingParam
from .engine import Engine
from .request import RequestSender, RequestType, Response, ResponseType
logger = get_logger('lmdeploy')
+InputMultiModalType = List[Dict[str, Any]]
+
def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None):
"""check if response has state."""
@@ -42,8 +43,8 @@ async def async_try_add_session(req_sender: RequestSender, session_id: int):
async def async_end(req_sender: RequestSender, session_id: int):
"""End the given session."""
- await req_sender.async_send_async(
- RequestType.END_SESSION, dict(session_id=session_id, response=False))
+ req_sender.send_async(RequestType.END_SESSION,
+ dict(session_id=session_id, response=False))
async def async_cancel(req_sender: RequestSender, session_id: int):
@@ -114,15 +115,13 @@ def _try_add_session(self, session_id: int):
"""
return try_add_session(self.req_sender, session_id)
- async def async_stream_infer(
- self,
- session_id: int,
- input_ids: List[int],
- gen_config: GenerationConfig = None,
- adapter_name: str = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
- **kwargs):
+ async def async_stream_infer(self,
+ session_id: int,
+ input_ids: List[int],
+ gen_config: GenerationConfig = None,
+ multimodal: InputMultiModalType = None,
+ adapter_name: str = None,
+ **kwargs):
"""Send stream inference request.
Args:
@@ -141,52 +140,37 @@ async def async_stream_infer(
return
gen_config = gen_config or GenerationConfig()
sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
- await self.req_sender.async_send_async(
- RequestType.ADD_SESSION, dict(session_id=session_id,
- response=False))
- input_embeddings_new: List[InputEmbeddings] = None
- if input_embeddings is not None and len(input_embeddings) > 0:
- assert len(input_embeddings) == len(input_embedding_ranges)
- input_embeddings_new = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_embeddings, input_embedding_ranges)
- ]
- msg = dict(token_ids=input_ids,
- session_id=session_id,
- sampling_param=sampling_param,
- adapter_name=adapter_name,
- input_embeddings=input_embeddings_new,
- mrope_position_ids=kwargs.get('mrope_position_ids'),
- mrope_position_delta=kwargs.get('mrope_position_delta'),
- cross_attention_states=kwargs.get('cross_attention_states'))
- req_id = await self.req_sender.async_send_async(
- RequestType.ADD_MESSAGE, msg)
+ self.req_sender.send_async(RequestType.ADD_SESSION,
+ dict(session_id=session_id, response=False))
+ msg = dict(
+ token_ids=input_ids,
+ session_id=session_id,
+ sampling_param=sampling_param,
+ adapter_name=adapter_name,
+ input_multimodals=multimodal,
+ )
+ resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
- token_ids = []
while True:
- resp = await self.req_sender.async_recv(req_id)
+ resp = await self.req_sender.async_recv(resp)
- if resp.req_id != req_id:
- continue
if resp.type == ResponseType.SUCCESS:
- token_ids += resp.data['token_ids']
+ token_ids = resp.data['token_ids'].tolist()
yield EngineOutput(resp.type, token_ids, len(token_ids))
elif resp.type == ResponseType.FINISH:
- token_ids += resp.data['token_ids']
+ token_ids = resp.data['token_ids'].tolist()
yield EngineOutput(resp.type, token_ids, len(token_ids))
break
else:
yield EngineOutput(resp.type, [], 0)
break
- async def async_infer(
- self,
- session_id: int,
- input_ids: List[int] = None,
- gen_config: GenerationConfig = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
- **kwargs):
+ async def async_infer(self,
+ session_id: int,
+ input_ids: List[int] = None,
+ multimodal: InputMultiModalType = None,
+ gen_config: GenerationConfig = None,
+ **kwargs):
"""Send inference request.
Args:
@@ -200,13 +184,11 @@ async def async_infer(
int: The number of the output tokens.
"""
token_ids = []
- async for outputs in self.async_stream_infer(
- session_id,
- input_ids,
- gen_config=gen_config,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- **kwargs):
+ async for outputs in self.async_stream_infer(session_id,
+ input_ids,
+ multimodal=multimodal,
+ gen_config=gen_config,
+ **kwargs):
status, tmp_ids = outputs.status, outputs.token_ids
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return EngineOutput(status, token_ids, len(token_ids))
@@ -217,10 +199,9 @@ async def async_infer(
def stream_infer(self,
session_id: int,
input_ids: List[int],
+ multimodal: InputMultiModalType = None,
gen_config: GenerationConfig = None,
adapter_name: str = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
**kwargs):
"""Send stream inference request.
@@ -241,14 +222,12 @@ def stream_infer(self,
def __call_async():
"""call async."""
- coro_gen = self.async_stream_infer(
- session_id,
- input_ids,
- gen_config,
- adapter_name,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- **kwargs)
+ coro_gen = self.async_stream_infer(session_id,
+ input_ids,
+ multimodal=multimodal,
+ gen_config=gen_config,
+ adapter_name=adapter_name,
+ **kwargs)
while True:
try:
yield self.req_sender.run_until_complete(
@@ -256,53 +235,13 @@ def __call_async():
except StopAsyncIteration:
break
- if not self.req_sender.is_thread_safe():
- yield from __call_async()
- return
-
- gen_config = gen_config or GenerationConfig()
- sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
- self.req_sender.send_async(RequestType.ADD_SESSION,
- dict(session_id=session_id, response=False))
- input_embeddings_new: List[InputEmbeddings] = None
- if input_embeddings is not None and len(input_embeddings) > 0:
- assert len(input_embeddings) == len(input_embedding_ranges)
- input_embeddings_new = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_embeddings, input_embedding_ranges)
- ]
- msg = dict(
- token_ids=input_ids,
- session_id=session_id,
- sampling_param=sampling_param,
- adapter_name=adapter_name,
- input_embeddings=input_embeddings_new,
- )
- req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
-
- token_ids = []
- while True:
- resp = self.req_sender.recv(req_id)
-
- if resp.req_id != req_id:
- continue
- if resp.type == ResponseType.SUCCESS:
- token_ids += resp.data['token_ids']
- yield EngineOutput(resp.type, token_ids, len(token_ids))
- elif resp.type == ResponseType.FINISH:
- token_ids += resp.data['token_ids']
- yield EngineOutput(resp.type, token_ids, len(token_ids))
- break
- else:
- yield EngineOutput(resp.type, [], 0)
- break
+ yield from __call_async()
def infer(self,
session_id: int,
input_ids: List[int] = None,
+ multimodal: InputMultiModalType = None,
gen_config: GenerationConfig = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
**kwargs):
"""Send inference request.
@@ -317,13 +256,11 @@ def infer(self,
int: The number of the output tokens.
"""
token_ids = []
- for outputs in self.stream_infer(
- session_id,
- input_ids,
- gen_config=gen_config,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- **kwargs):
+ for outputs in self.stream_infer(session_id,
+ input_ids,
+ multimodal=multimodal,
+ gen_config=gen_config,
+ **kwargs):
status, tmp_ids = outputs.status, outputs.token_ids
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return EngineOutput(status, token_ids, len(token_ids))
@@ -331,127 +268,6 @@ def infer(self,
return EngineOutput(0, token_ids, len(token_ids))
- async def async_batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None,
- ):
- """Send inference request.
-
- Args:
- session_ids (List[int]): The session id.
- token_ids (List[int]): The input token ids.
- gen_config (GenerationConfig): The sampling parameters.
- adapter_names (List[str]): The name of the adapters.
- keep_cache (bool): Keep kv cache after infer.
-
- Returns:
- int: Error flags. 0 if success.
- List[int]: The streaming output tokens.
- int: The number of the output tokens.
- """
- batch_size = len(token_ids)
- assert len(session_ids) == batch_size
- if adapter_names is not None:
- assert len(adapter_names) == batch_size
- else:
- adapter_names = [None for _ in range(batch_size)]
-
- if input_embeddings is not None:
- assert len(input_embeddings) == batch_size
- assert len(input_embedding_ranges) == batch_size
- else:
- input_embeddings = [None] * batch_size
- input_embedding_ranges = [None] * batch_size
-
- async def _add_sessions(session_ids):
- for session_id in session_ids:
- await self._async_try_add_session(session_id)
-
- async def _add_messages(session_ids, token_ids, adapter_names,
- input_embeddings, input_embedding_ranges):
- add_msgs = []
- sampling_param = SamplingParam.from_gen_config(gen_config)
- for session_id, token_id, adapter_name, input_emb, input_ranges in zip( # noqa: E501
- session_ids, token_ids, adapter_names, input_embeddings,
- input_embedding_ranges):
- cur_input_embeddings: List[InputEmbeddings] = None
- if input_emb is not None and len(input_emb) > 0:
- assert len(input_emb) == len(input_ranges)
- cur_input_embeddings = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_emb, input_ranges)
- ]
- msg = dict(
- token_ids=token_id,
- session_id=session_id,
- sampling_param=sampling_param,
- adapter_name=adapter_name,
- input_embeddings=cur_input_embeddings,
- )
- add_msgs.append(msg)
- req_types = [RequestType.ADD_MESSAGE] * batch_size
- req_ids = await self.req_sender.async_batched_send_async(
- req_types, data=add_msgs)
- return req_ids
-
- await _add_sessions(session_ids)
- req_ids = await _add_messages(session_ids, token_ids, adapter_names,
- input_embeddings, input_embedding_ranges)
-
- # receive messages
- req_idx_map = dict(zip(req_ids, range(len(req_ids))))
- output_token_ids = [list() for _ in req_ids]
- status = 0
- finish_count = batch_size
- while finish_count:
- resp = await self.req_sender.async_recv_any()
- if resp.req_id not in req_ids:
- continue
- idx = req_idx_map[resp.req_id]
- token_ids = output_token_ids[idx]
- if resp.type == ResponseType.SUCCESS:
- token_ids += resp.data['token_ids']
- elif resp.type == ResponseType.FINISH:
- token_ids += resp.data['token_ids']
- if not keep_cache:
- session_id = session_ids[idx]
- await self.async_end(session_id=session_id)
- finish_count -= 1
- else:
- logger.error(f'Unexpected response: {resp.type}')
- status = 1
- break
-
- output_token_len = [len(token_ids) for token_ids in output_token_ids]
- return EngineOutput(status, output_token_ids, output_token_len)
-
- def batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None,
- ):
- """batched infer."""
- coro = self.async_batched_infer(
- session_ids,
- token_ids,
- gen_config=gen_config,
- adapter_names=adapter_names,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- keep_cache=keep_cache)
- return self.req_sender.run_until_complete(coro)
-
async def async_end(self, session_id: int):
"""End the given session."""
return await async_end(self.req_sender, session_id)
@@ -470,8 +286,7 @@ def cancel(self, session_id: int):
def decode(self,
input_ids,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None,
+ multimodal: List[InputMultiModalType] = None,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
@@ -481,10 +296,8 @@ def decode(self,
Args:
input_ids (numpy.ndarray): the batch of input token ids
steps (List[int]): the offset of the k/v cache
- input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
- embeddings features
- input_embedding_ranges: (List[List[Tuple[int, int]]]):
- the begin/end offsets of input_embeddings to input_ids
+ multimodal (List[InputMultiModalType]):
+                multimodal inputs.
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
@@ -494,39 +307,30 @@ def decode(self,
batch_size = len(input_ids)
def __add_messages(session_ids, input_ids, adapter_names,
- input_embeddings, input_embedding_ranges):
+ input_multimodals):
add_msgs = []
sampling_param = SamplingParam(max_new_tokens=0)
batch_size = len(input_ids)
- if input_embeddings is None:
- input_embeddings = [None] * batch_size
- input_embedding_ranges = [None] * batch_size
- for (session_id, token_id, adapter_name, input_emb,
- input_ranges) in zip(session_ids, input_ids, adapter_names,
- input_embeddings,
- input_embedding_ranges):
+ if input_multimodals is None:
+ input_multimodals = [None] * batch_size
+ for (session_id, token_id, adapter_name,
+ in_mm) in zip(session_ids, input_ids, adapter_names,
+ input_multimodals):
if len(token_id) > self.max_input_len:
raise RuntimeError(
f'Expect input length<={self.max_input_len} '
f'but get {len(token_id)}')
- cur_input_embeddings: List[InputEmbeddings] = None
- if input_emb is not None and len(input_emb) > 0:
- assert len(input_emb) == len(input_ranges)
- cur_input_embeddings = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_emb, input_ranges)
- ]
msg = dict(token_ids=token_id,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name,
- input_embeddings=cur_input_embeddings,
+ input_multimodals=in_mm,
return_logits=True)
add_msgs.append(msg)
req_types = [RequestType.ADD_MESSAGE] * batch_size
- req_ids = self.req_sender.batched_send_async(req_types,
- data=add_msgs)
- return req_ids
+ resps = self.req_sender.batched_send_async(req_types,
+ data=add_msgs)
+ return resps
if steps is not None:
assert batch_size == len(steps)
@@ -536,13 +340,6 @@ def __add_messages(session_ids, input_ids, adapter_names,
else:
adapter_names = [None] * batch_size
- if input_embeddings is not None:
- assert len(input_embeddings) == batch_size
- assert len(input_embedding_ranges) == batch_size
- else:
- input_embeddings = [None] * batch_size
- input_embedding_ranges = [None] * batch_size
-
session_ids = tuple(range(batch_size))
if sequence_start:
for sid in session_ids:
@@ -550,21 +347,14 @@ def __add_messages(session_ids, input_ids, adapter_names,
dict(session_id=sid))
self._try_add_session(sid)
- req_ids = __add_messages(session_ids, input_ids, adapter_names,
- input_embeddings, input_embedding_ranges)
- req_idx_map = dict(zip(req_ids, range(len(req_ids))))
-
- finish_count = batch_size
- ret = [None] * batch_size
- while finish_count > 0:
- resp = self.req_sender.recv_any()
- if resp.req_id not in req_ids:
- continue
+ resps = __add_messages(session_ids, input_ids, adapter_names,
+ multimodal)
+ ret = []
+ for resp in resps:
+ resp = self.req_sender.recv(resp)
assert resp.type == ResponseType.FINISH
- idx = req_idx_map[resp.req_id]
- ret[idx] = resp.data['logits']
- finish_count -= 1
+ ret.append(resp.data['logits'])
ret = pad_sequence(ret, True)
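
The engine-instance API above drops the `input_embeddings`/`input_embedding_ranges` pair in favor of a single `multimodal` argument. A minimal usage sketch of the patched streaming API, assuming `instance` is an engine instance exposing `async_stream_infer` and `async_end` as shown above; the session id, token ids and generation settings are illustrative only:

```python
import asyncio

from lmdeploy.messages import GenerationConfig, ResponseType


async def generate(instance, session_id: int, input_ids, multimodal=None):
    """Collect output tokens from the patched streaming API (sketch)."""
    token_ids = []
    gen_config = GenerationConfig(max_new_tokens=128)
    async for out in instance.async_stream_infer(session_id,
                                                 input_ids,
                                                 multimodal=multimodal,
                                                 gen_config=gen_config):
        if out.status not in (ResponseType.SUCCESS, ResponseType.FINISH):
            break
        token_ids += out.token_ids
    await instance.async_end(session_id)
    return token_ids

# asyncio.run(generate(instance, session_id=0, input_ids=[1, 2, 3]))
```
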
diff --git a/lmdeploy/pytorch/engine/input_process.py b/lmdeploy/pytorch/engine/input_process.py
new file mode 100644
index 0000000000..7f442e153b
--- /dev/null
+++ b/lmdeploy/pytorch/engine/input_process.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
+
+TypeModelMetas = Dict[str, Any]
+
+InputMultiModalType = List[Dict[str, Any]]
+
+
+@dataclass
+class PreprocessInputResult:
+ """results of preprocess input."""
+ input_ids: List[int]
+ input_multimodals: Optional[MultiModalInputs] = None
+ model_metas: Optional[TypeModelMetas] = None
+
+
+class BaseModelInputProcessor(ABC):
+ """processor of model inputs."""
+
+ @abstractmethod
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_mms: InputMultiModalType = None,
+ **kwargs) -> PreprocessInputResult:
+ """preprocess input."""
+ raise NotImplementedError('Not implemented.')
+
+
+class DefaultModelInputProcessor(BaseModelInputProcessor):
+ """default model input processor."""
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_mms: MultiModalInputs = None,
+ **kwargs) -> PreprocessInputResult:
+ """preprocess input."""
+ return PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=input_mms,
+ )
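
The new `input_process` module gives multimodal models a hook to fold extra inputs into the token stream before the engine schedules them. A minimal sketch of subclassing it; the pass-through behavior is purely illustrative (a real vision-language model would merge image features and patch placeholder token ids in `preprocess_input`):

```python
from typing import List

from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
                                                   DefaultModelInputProcessor,
                                                   PreprocessInputResult)


class PassThroughProcessor(BaseModelInputProcessor):
    """Illustrative processor: forward token ids and multimodal inputs as-is."""

    def preprocess_input(self,
                         input_ids: List[int],
                         input_mms=None,
                         **kwargs) -> PreprocessInputResult:
        return PreprocessInputResult(input_ids=input_ids,
                                     input_multimodals=input_mms)


# the default processor behaves the same way
result = DefaultModelInputProcessor().preprocess_input([1, 2, 3])
assert result.input_ids == [1, 2, 3] and result.input_multimodals is None
```
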
diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index 24cb336d71..f7ca9c5116 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple
@@ -298,9 +299,15 @@ def __init__(self,
self.ignore_eos = ignore_eos
self.tokenizer = tokenizer
- def __call__(self, all_ids: torch.LongTensor,
- guided_input_ids: torch.LongTensor,
- scores: torch.FloatTensor) -> torch.FloatTensor:
+ async def _wait_stream_once(self):
+ """wait stream once."""
+ stream = torch.cuda.current_stream()
+ if not stream.query():
+ await asyncio.sleep(0)
+
+ async def __call__(self, all_ids: torch.LongTensor,
+ guided_input_ids: torch.LongTensor,
+ scores: torch.FloatTensor) -> torch.FloatTensor:
r"""
Args:
all_ids (torch.LongTensor): All the token ids.
@@ -320,6 +327,7 @@ def __call__(self, all_ids: torch.LongTensor,
custom_logits_processors = self.sampling_inputs.logits_processors
if any(custom_logits_processors):
+ await self._wait_stream_once()
scores = _apply_custom_logits_processors(custom_logits_processors,
all_ids, scores)
@@ -343,8 +351,10 @@ def __call__(self, all_ids: torch.LongTensor,
stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False)
scores = _process_bad_words_(scores, stop_words, stop_mask)
- scores = _guided_sampling(sampling_inputs.response_formats, scores,
- guided_input_ids, self.tokenizer)
+ if guided_input_ids is not None:
+ await self._wait_stream_once()
+ scores = _guided_sampling(sampling_inputs.response_formats, scores,
+ guided_input_ids, self.tokenizer)
return scores
@torch.inference_mode()
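
Turning `__call__` into a coroutine lets the CPU-heavy paths (custom logits processors, guided sampling) yield to the event loop rather than stall it while the CUDA stream is still busy. A standalone sketch of the `_wait_stream_once` idea, runnable outside lmdeploy:

```python
import asyncio

import torch


async def wait_stream_once():
    """Yield control once if the current CUDA stream has pending work.

    Mirrors _wait_stream_once above: CPU-bound post-processing is pushed
    back by one scheduling round so other coroutines can run while the GPU
    catches up; it does not spin until the stream is fully drained.
    """
    stream = torch.cuda.current_stream()
    if not stream.query():
        await asyncio.sleep(0)


if torch.cuda.is_available():
    asyncio.run(wait_stream_once())
```
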
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 03fb083265..9b98dfb1fe 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -135,21 +135,26 @@ def model_forward(
stream = stream or torch.cuda.current_stream()
with torch.cuda.stream(stream):
# forward
- inputs = inputs.to_device('cuda')
ctx_mgr = model.ctx_mgr
context = ctx_mgr.build_context(
inputs=inputs,
+ model_config=cache_engine.model_config,
world_size=world_size,
kv_caches=cache_engine.gpu_cache,
kv_quant_policy=cache_engine.cache_config.quant_policy,
)
with ctx_mgr.context(context):
+ model_metas = None
+ model_metas = model.update_model_metas(
+ past_key_values=cache_engine.gpu_cache,
+ context=context,
+ )
input_dict = model.prepare_inputs_for_generation(
past_key_values=cache_engine.gpu_cache,
context=context,
)
output = model(**input_dict)
- return dict(hidden_states=output)
+ return dict(hidden_states=output, model_metas=model_metas)
SwapMap = Dict[int, int]
@@ -177,6 +182,10 @@ def get_logits(self, hidden_states: torch.Tensor):
"""get logits of model output."""
raise NotImplementedError('Not implemented.')
+ def get_input_processor(self):
+ """get input processor."""
+ raise NotImplementedError('Not implemented.')
+
class BaseModelAgent(AutoModelAgent):
"""Base model agent.
@@ -288,8 +297,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
- while not self.stream.query():
- await asyncio.sleep(0)
return output
async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap,
@@ -358,6 +365,10 @@ def get_spec_logits(self, hidden_states_list: List[torch.Tensor]):
"""get logits of model output."""
return self.speculative_model.get_logits(hidden_states_list)
+ def get_input_processor(self):
+        """get input processor."""
+ return self.patched_model.get_input_processor()
+
@torch.inference_mode()
def _tp_build_model(
@@ -443,14 +454,26 @@ def _broadcast_config(cache_config):
return patched_model, cache_engine, cache_config
-def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
+def _broadcast_inputs(rank: int, inputs: Any, group: dist.group,
+ stream: torch.cuda.Stream):
"""get input tensor parallel."""
# broadcast meta info
if rank != 0:
inputs = [None, None, None]
+ else:
+ device_inputs = inputs[0]
+ meta_inputs = device_inputs.to_device('meta')
+ inputs[0] = meta_inputs
with torch.cuda.stream(stream):
- dist.broadcast_object_list(inputs)
+ dist.broadcast_object_list(inputs, group=group)
+ if rank == 0:
+ device_inputs.broadcast()
+ else:
+ device_inputs = inputs[0].broadcast()
+
+ inputs[0] = device_inputs
+
return inputs
@@ -463,6 +486,7 @@ def _tp_model_loop(
adapters: Dict[str, str],
world_size: int,
barrier: mp.Barrier,
+ cpu_group: dist.group,
):
"""Start model loops for tensor parallel model inference.
@@ -488,11 +512,12 @@ def _tp_model_loop(
while True:
barrier.wait()
inputs, swap_in_map, swap_out_map = _broadcast_inputs(
- rank, None, stream)
+ rank, None, cpu_group, stream)
cache_swapping(cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
+ inputs = inputs.to_device('cuda')
model_forward(
patched_model,
@@ -524,10 +549,13 @@ def _start_tp_process(proc_id: int,
try:
from lmdeploy.pytorch.check_env import check_env_deeplink
check_env_deeplink(device_context.device_type)
+ timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
- timeout=timedelta(days=35600))
+ timeout=timeout)
+ cpu_group = dist.new_group(timeout=timeout, backend='gloo')
+ kwargs['cpu_group'] = cpu_group
dist_ctx = DistContext(rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
with get_dist_manager().context(dist_ctx), get_device_manager(
@@ -706,12 +734,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
rank = 0
try:
+ timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
- timeout=timedelta(days=35600))
+ timeout=timeout)
+ cpu_group = dist.new_group(timeout=timeout, backend='gloo')
dist_ctx = DistContext(rank=rank, world_size=world_size)
self._dist_ctx = dist_ctx
+ self._cpu_group = cpu_group
except Exception as e:
from traceback import print_exc
logger.error(f'Rank[{rank}] failed.')
@@ -782,7 +813,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
self.mp_bar.wait()
rank = 0
_broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
- self.stream)
+ self._cpu_group, self.stream)
+
cache_swapping(self.cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
@@ -850,8 +882,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
- while not self.stream.query():
- await asyncio.sleep(0)
return output
async def tree_decoding(self, inputs: ModelInputs, swap_in_map: SwapMap,
@@ -897,6 +927,10 @@ def get_spec_logits(self, hidden_states_list: List[torch.Tensor]):
"""get logits of model output."""
return self.speculative_model.get_logits(hidden_states_list)
+ def get_input_processor(self):
+        """get input processor."""
+ return self.patched_model.get_input_processor()
+
def _exit_handler(agent: TPModelAgent):
if hasattr(agent, 'patched_model'):
@@ -926,7 +960,7 @@ def build_model_agent(model_path: str,
custom_module_map (str): customized nn module map
"""
model_config = ModelConfig.from_pretrained(
- model_path, trust_remote_code=trust_remote_code, dtype=dtype)
+ model_path, trust_remote_code=trust_remote_code, dtype=dtype, tp=tp)
model_config.custom_module_map = custom_module_map
speculative_model_config = None
if speculative_model is not None:
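
`_broadcast_inputs` now sends the picklable metadata over a dedicated `gloo` group while tensor payloads are broadcast on the NCCL side, so object serialization never blocks the GPU stream. A hedged sketch of the dual-group setup; it assumes the usual env-based rendezvous (`MASTER_ADDR` etc.), and `broadcast_meta` with its payload is illustrative, not part of this patch:

```python
from datetime import timedelta

import torch.distributed as dist


def init_groups(rank: int, world_size: int):
    """Create the NCCL default group plus a gloo group for CPU metadata."""
    timeout = timedelta(days=35600)
    dist.init_process_group('nccl', rank=rank, world_size=world_size,
                            timeout=timeout)
    cpu_group = dist.new_group(timeout=timeout, backend='gloo')
    return cpu_group


def broadcast_meta(meta, cpu_group, rank: int):
    """Broadcast a small python object from rank 0 over the CPU group."""
    obj = [meta] if rank == 0 else [None]
    dist.broadcast_object_list(obj, src=0, group=cpu_group)
    return obj[0]
```
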
diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py
index 18bd2193d4..0d20deb907 100644
--- a/lmdeploy/pytorch/engine/request.py
+++ b/lmdeploy/pytorch/engine/request.py
@@ -2,8 +2,6 @@
import asyncio
import enum
from dataclasses import dataclass, field
-from queue import Queue
-from threading import Lock, Thread
from typing import Any, Awaitable, Callable, Dict, List
from lmdeploy.messages import ResponseType
@@ -12,25 +10,6 @@
logger = get_logger('lmdeploy')
-def _raise_exception_on_finish(task: asyncio.Task) -> None:
- try:
- task.result()
- except asyncio.CancelledError:
- return
- except Exception as e:
- logger.exception(f'Engine loop failed with error: {e}')
-
-
-def _ignore_exception_on_finish(task: asyncio.Task) -> None:
- try:
- task.result()
- except asyncio.CancelledError:
- return
- except Exception as exc:
- logger.debug(f'task: {task.get_name()} ended.')
- logger.debug(f'task: {task.get_name()} exception: {exc}')
-
-
class RequestType(enum.Enum):
"""Request type."""
@@ -43,24 +22,24 @@ class RequestType(enum.Enum):
@dataclass
-class Request:
- """Request."""
+class Response:
+ """Response."""
- type: RequestType
+ type: ResponseType
sender_id: int
- req_id: int
+ event: asyncio.Event
data: Any = None
+ err_msg: str = ''
@dataclass
-class Response:
- """Response."""
+class Request:
+ """Request."""
- type: ResponseType
+ type: RequestType
sender_id: int
- req_id: int
data: Any = None
- err_msg: str = ''
+ resp: Response = None
ReqList = List[Request]
@@ -85,28 +64,20 @@ class RequestSender:
Args:
sender_id (int): The id of the sender
"""
-
sender_id: int
manager: 'RequestManager'
resp_dict: Dict[int, List[Response]] = field(default_factory=dict)
- _next_req_id: int = 0
_resp_que: asyncio.Queue = None
- _resp_thread_que: Queue = None
- _thread_safe: bool = False
@classmethod
def new(cls, sender_id: int, manager: 'RequestManager'):
"""new."""
obj = cls(sender_id=sender_id, manager=manager)
- obj._thread_safe = manager.is_thread_safe()
return obj
@property
def resp_que(self):
"""response queue."""
- thread_safe = self.is_thread_safe()
- if thread_safe:
- return self.manager.responses
if self._resp_que is not None:
return self._resp_que
if self.manager._loop_task is None:
@@ -119,27 +90,11 @@ def req_que(self):
"""request queue."""
return self.manager.requests
- @property
- def resp_thread_que(self):
- """response threadsafe queue."""
- if self._resp_thread_que is None:
- self._resp_thread_que = Queue()
- return self._resp_thread_que
-
- @property
- def req_thread_que(self):
- """request threadsafe queue."""
- return self.manager.thread_requests
-
@property
def event_loop(self):
"""get event loop."""
return self.manager.event_loop
- def is_thread_safe(self):
- """is thread safe."""
- return self._thread_safe
-
def is_loop_alive(self):
"""is loop alive."""
return self.manager.is_loop_alive()
@@ -148,203 +103,72 @@ def run_until_complete(self, future: Awaitable):
"""run untile complete."""
return self.manager.run_until_complete(future)
- def _resp_get(self):
- """resp_que.get."""
- timeout = 1.0
- que = self.resp_thread_que
- not_empty = que.not_empty
- with not_empty:
- while not que._qsize():
- not_empty.wait(timeout)
- return que.get_nowait()
-
- async def _async_resp_get(self):
- """get resp.
-
- Different behavior in threadsafe mode.
- """
- timeout = 1
-
- async def __no_threadsafe_get():
- while True:
- try:
- return await asyncio.wait_for(self.resp_que.get(), timeout)
- except asyncio.TimeoutError:
- if not self.manager.is_loop_alive():
- logger.debug('Engine loop is not alive.')
- exit(1)
- continue
- except Exception as e:
- logger.exception(
- f'sender[{self.sender_id}] get response failed: {e}')
- raise e
-
- if self.is_thread_safe():
- ret = self._resp_get()
- await asyncio.sleep(0)
- return ret
- else:
- return await __no_threadsafe_get()
-
def _req_put(self, reqs: Any):
- """req put."""
- self.req_thread_que.put(reqs)
-
- async def _async_req_put(self, reqs: Any):
- """async rq_que put.
-
- Different behavior in threadsafe mode.
- """
- if self.is_thread_safe():
- self._req_put(reqs)
- await asyncio.sleep(0)
- else:
- await self.req_que.put(reqs)
-
- def _prefetch_resps(self):
- """prefetch from resp que.
-
- Different behavior in threadsafe mode.
- """
- if self.is_thread_safe():
- resp_que = self.resp_thread_que
- else:
- resp_que = self.resp_que
- num_resps = resp_que.qsize()
- for _ in range(num_resps):
- resp: Response = resp_que.get_nowait()
- req_id = resp.req_id
- self._push_resp(req_id, resp)
-
- def _push_resp(self, req_id: int, resp: Response):
- """push response."""
- self.resp_dict.setdefault(req_id, [])
- self.resp_dict[req_id].append(resp)
-
- def _pop_resp(self, req_id: int, default: Any = None):
- """pop response."""
- if req_id not in self.resp_dict:
- return default
- resps = self.resp_dict[req_id]
- ret = resps.pop(0)
- if len(resps) == 0:
- self.resp_dict.pop(req_id)
- return ret
+        """req_que put."""
+ self.req_que.put_nowait(reqs)
def _gather_request(self, req_types: List[RequestType], data: List[Any]):
"""gather requests."""
- if self.manager._loop_task is None and not self.is_thread_safe():
+ if self.manager._loop_task is None:
self.manager.create_loop_task()
assert len(req_types) == len(data)
- batch_size = len(req_types)
-
- req_ids = list(range(self._next_req_id,
- self._next_req_id + batch_size))
- self._next_req_id += batch_size
-
- reqs = [
- Request(type=rtype,
- sender_id=self.sender_id,
- req_id=req_id,
- data=rdata)
- for req_id, rtype, rdata in zip(req_ids, req_types, data)
- ]
- return req_ids, reqs
- async def async_batched_send_async(self, req_types: List[RequestType],
- data: List[Any]):
- """Batched send request asynchronize."""
- req_ids, reqs = self._gather_request(req_types, data)
- await self._async_req_put(reqs)
- return req_ids
-
- async def async_send_async(self, req_type: RequestType, data: Any):
- """send request asynchronize."""
- return (await self.async_batched_send_async(req_types=[req_type],
- data=[data]))[0]
+ reqs = []
+ resps = []
+ for rtype, rdata in zip(req_types, data):
+ event = asyncio.Event()
+ resp = Response(type=ResponseType.HANDLER_NOT_EXIST,
+ sender_id=self.sender_id,
+ event=event,
+ data=None,
+ err_msg=None)
+ req = Request(type=rtype,
+ sender_id=self.sender_id,
+ data=rdata,
+ resp=resp)
+ resps.append(resp)
+ reqs.append(req)
+ return resps, reqs
def batched_send_async(self, req_types: List[RequestType],
- data: List[Any]) -> List[int]:
- """Batched send request asynchronize.
-
- Different behavior in threadsafe mode.
- """
- if not self.is_thread_safe():
- coro = self.async_batched_send_async(req_types, data)
- return self.run_until_complete(coro)
-
- req_ids, reqs = self._gather_request(req_types, data)
+ data: List[Any]):
+        """Batched send request asynchronously."""
+ resps, reqs = self._gather_request(req_types, data)
self._req_put(reqs)
- return req_ids
+ return resps
- def send_async(self, req_type: RequestType, data: Any) -> int:
+ def send_async(self, req_type: RequestType, data: Any):
"""send request asynchronize."""
return self.batched_send_async(req_types=[req_type], data=[data])[0]
- async def async_recv_any(self) -> Response:
- """receive any response."""
- self._prefetch_resps()
- for req_id in self.resp_dict:
- ret = self._pop_resp(req_id, default=None)
- if ret is not None:
- return ret
- return await self._async_resp_get()
-
- def recv_any(self) -> Response:
- """receive any response."""
- coro = self.async_recv_any()
- return self.run_until_complete(coro)
-
- def recv_all(self, req_id: int, block: bool = True):
- """revceive all response with req_id."""
- self._prefetch_resps()
- resps = self.resp_dict.pop(req_id, [])
- return resps
-
- async def async_recv(self, req_id: int) -> Response:
+ async def async_recv(self, resp: Response) -> Response:
"""receive response of given request id async."""
- ret = self._pop_resp(req_id, default=None)
- if ret is not None:
- return ret
-
- # check resp que
- while True:
- resp: Response = await self._async_resp_get()
- if resp.req_id != req_id:
- self._push_resp(req_id, resp)
- else:
- return resp
-
- def recv(self, req_id: int) -> Response:
- """receive response of given request id.
-
- Different behavior in threadsafe mode.
- """
- if not self.is_thread_safe():
- coro = self.async_recv(req_id)
- return self.run_until_complete(coro)
-
- ret = self._pop_resp(req_id, default=None)
- if ret is not None:
- return ret
-
- # check resp que
- while True:
- resp: Response = self._resp_get()
- if resp.req_id != req_id:
- self._push_resp(req_id, resp)
- else:
- return resp
+ event = resp.event
+ while not event.is_set():
+ try:
+ await asyncio.wait_for(event.wait(), 1)
+ except asyncio.TimeoutError:
+ if self.is_loop_alive():
+ continue
+ logger.debug('Engine main loop failed.')
+ break
+ event.clear()
+ return resp
+
+ def recv(self, resp: Response) -> Response:
+        """receive response of the given request."""
+ coro = self.async_recv(resp)
+ return self.run_until_complete(coro)
async def async_send(self, req_type: RequestType, data: Any):
"""send and receive synchronize."""
- req_id = await self.async_send_async(req_type, data)
- return await self.async_recv(req_id)
+ resp = self.send_async(req_type, data)
+ return await self.async_recv(resp)
def send(self, req_type: RequestType, data: Any) -> Response:
"""send and receive synchronize."""
- req_id = self.send_async(req_type, data)
- return self.recv(req_id)
+ resp = self.send_async(req_type, data)
+ return self.recv(resp)
def response_callback(self, resp: Response):
"""response callback."""
@@ -354,7 +178,7 @@ def response_callback(self, resp: Response):
class RequestManager:
"""Request manager."""
- def __init__(self, thread_safe: bool = False):
+ def __init__(self):
self.senders: Dict[int, RequestSender] = dict()
self.callbacks: Dict[RequestType, Callable] = dict()
self.request_priority: List[RequestType] = [
@@ -365,17 +189,7 @@ def __init__(self, thread_safe: bool = False):
self.requests: asyncio.Queue = None
self._loop_task: asyncio.Future = None
self._loop_coro: Callable = None
- self._thread_safe = thread_safe
self._next_sender_id = 0
- self._mutex = Lock()
- self._loop_thread: Thread = None
-
- self.thread_requests: Queue = None
- # every sender has it's own responses, this responses is
- # only used in thread safe mode.
- self.responses: asyncio.Queue = None
- if thread_safe:
- self.thread_requests = Queue()
def create_loop_task(self):
"""create coro task."""
@@ -385,7 +199,6 @@ def create_loop_task(self):
'Please set loop task with manager.start_loop')
loop_unshielded = event_loop.create_task(self._loop_coro(),
name='EngineMainLoop')
- loop_unshielded.add_done_callback(_raise_exception_on_finish)
self._loop_task = asyncio.shield(loop_unshielded)
self.requests = asyncio.Queue()
return self._loop_task
@@ -398,105 +211,17 @@ def event_loop(self):
else:
return self._loop_task.get_loop()
- def is_thread_safe(self):
- """is thread safe."""
- return self._thread_safe
-
def start_loop(self, loop: asyncio.Task):
"""start main loop."""
self._loop_coro = loop
- def __get_thread_reqs():
- """get thread reqs."""
- num_reqs = self.thread_requests.qsize()
- reqs = []
- for _ in range(num_reqs):
- tmp_reqs = self.thread_requests.get_nowait()
- if isinstance(tmp_reqs, Request):
- tmp_reqs = [tmp_reqs]
- reqs += tmp_reqs
- return reqs
-
- async def __async_get_req(event_loop):
- """async get request."""
- que = self.thread_requests
- not_empty = que.not_empty
- with not_empty:
- while not que._qsize():
- await event_loop.run_in_executor(None, not_empty.wait, 1.0)
- reqs = que.get_nowait()
- if isinstance(reqs, Request):
- reqs = [reqs]
- return reqs
-
- async def __req_loop():
- """req loop."""
- event_loop = asyncio.get_event_loop()
- while True:
- # get reqs
- reqs = __get_thread_reqs()
- if len(reqs) == 0:
- reqs = await __async_get_req(event_loop)
- self.requests.put_nowait(reqs)
-
- def __put_thread_resps(resps: List[Response]):
- """put thread resps."""
- for resp in resps:
- sender = self.senders.get(resp.sender_id, None)
- if sender is None:
- continue
- sender.resp_thread_que.put_nowait(resp)
-
- async def __resp_loop():
- """resp loop."""
- while True:
- num_resps = self.responses.qsize()
-
- if num_resps == 0:
- resps = [await self.responses.get()]
- else:
- resps = []
- for _ in range(num_resps):
- resps.append(self.responses.get_nowait())
- __put_thread_resps(resps)
- await asyncio.sleep(0)
-
- def __run_forever(event_loop: asyncio.BaseEventLoop):
- """run forever."""
- logger.debug('start thread run forever.')
- asyncio.set_event_loop(event_loop)
- self.responses = asyncio.Queue()
- self.create_loop_task()
- req_loop = event_loop.create_task(__req_loop(),
- name='RunForeverReqLoop')
- req_loop.add_done_callback(_ignore_exception_on_finish)
- resp_loop = event_loop.create_task(__resp_loop(),
- name='RunForeverRespLoop')
- resp_loop.add_done_callback(_ignore_exception_on_finish)
- self.event_loop.run_forever()
-
- if self.is_thread_safe():
- event_loop = asyncio.new_event_loop()
- self._loop_thread = Thread(target=__run_forever,
- args=(event_loop, ),
- daemon=True)
- self._loop_thread.start()
+    def stop_loop(self):
+        """stop the main loop."""
+        if self.is_loop_alive():
+            self._loop_task.cancel()
def is_loop_alive(self):
"""check if main loop is alive."""
- def __check_threadsafe():
- if self._loop_thread is None:
- return False
- if not self._loop_thread.is_alive():
- return False
- if self._loop_task is None:
- return False
- return not self._loop_task.done()
-
- if self.is_thread_safe():
- return __check_threadsafe()
-
if self._loop_task is None:
logger.debug('loop task has not been created.')
return False
@@ -508,12 +233,11 @@ def __check_threadsafe():
def build_sender(self):
"""create a new sender."""
- with self._mutex:
- sender_id = self._next_sender_id
- self._next_sender_id += 1
- new_sender = RequestSender.new(sender_id, self)
- self.senders[sender_id] = new_sender
- return new_sender
+ sender_id = self._next_sender_id
+ self._next_sender_id += 1
+ new_sender = RequestSender.new(sender_id, self)
+ self.senders[sender_id] = new_sender
+ return new_sender
def has_requests(self):
"""has unprocessed request."""
@@ -521,16 +245,27 @@ def has_requests(self):
return False
return not self.requests.empty()
- def get_all_requests(self) -> Dict[RequestType, Request]:
+ async def get_all_requests(self) -> Dict[RequestType, Request]:
"""get all requests in current queue."""
num_reqs = self.requests.qsize()
reqs: ReqList = []
- for _ in range(num_reqs):
- elem = self.requests.get_nowait()
+
+ def __proc_reqs(elem):
+ """proc reqs."""
+ nonlocal reqs
if isinstance(elem, Request):
elem = [elem]
reqs += elem
+ if num_reqs == 0:
+ elem = await self.requests.get()
+ __proc_reqs(elem)
+ num_reqs = self.requests.qsize()
+
+ for _ in range(num_reqs):
+ elem = self.requests.get_nowait()
+ __proc_reqs(elem)
+
# gather requests
reqs_by_type: Dict[RequestType, Request] = dict(
(t, []) for t in RequestType)
@@ -548,11 +283,7 @@ def set_request_priority(self, priority: List[RequestType]):
def response(self, resp: Response):
"""send response."""
- if resp.sender_id not in self.senders:
- logger.warning(f'sender {resp.sender_id} not exist. '
- f'Send {resp} failed.')
- return
- self.senders[resp.sender_id].response_callback(resp)
+ resp.event.set()
def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):
"""process reqs with given req type."""
@@ -563,19 +294,18 @@ def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):
else:
# TODO: send error message
for req in reqs:
- resp = Response(ResponseType.HANDLER_NOT_EXIST,
- sender_id=req.sender_id,
- req_id=req.req_id,
- err_msg=(f'callback for {req_type}'
- ' not exists.'))
+ resp = req.resp
+ resp.type = ResponseType.HANDLER_NOT_EXIST
+ resp.err_msg = (f'callback for {req_type}'
+                                ' does not exist.')
self.response(resp)
- def step(self, **kwargs):
+ async def step(self, **kwargs):
"""handle requests.
Should only be called in loop task.
"""
- reqs_by_type = self.get_all_requests()
+ reqs_by_type = await self.get_all_requests()
# handle requests
for req_type in self.request_priority:
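
The `request.py` rewrite drops req_id bookkeeping and the thread-safe queues: each `Request` carries its own `Response` holding an `asyncio.Event`, the engine loop fills in the data and sets the event, and the sender awaits that event directly. A self-contained sketch of the pattern with plain asyncio (names are illustrative):

```python
import asyncio
from dataclasses import dataclass
from typing import Any


@dataclass
class Resp:
    event: asyncio.Event
    data: Any = None


async def worker(queue: asyncio.Queue):
    """Engine-loop side: fill the response and set its event."""
    while True:
        req_data, resp = await queue.get()
        resp.data = req_data * 2  # pretend to do work
        resp.event.set()          # wake the waiting sender


async def send(queue: asyncio.Queue, value: int) -> int:
    """Sender side: enqueue a request and await its own event."""
    resp = Resp(event=asyncio.Event())
    queue.put_nowait((value, resp))
    await resp.event.wait()       # no req_id matching needed
    return resp.data


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(worker(queue))
    print(await send(queue, 21))  # 42
    task.cancel()

asyncio.run(main())
```
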
diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py
index 3790cf0f66..b62ddef80a 100644
--- a/lmdeploy/pytorch/kernels/cuda/__init__.py
+++ b/lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -9,6 +9,7 @@
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
+from .w8a8_fused_moe import fused_moe_w8a8
from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant,
per_channel_quant, per_token_quant_int8,
rms_norm_dynamic_quant)
@@ -28,4 +29,5 @@
'rms_norm_dynamic_quant',
'flash_attention_fwd',
'flatten_kv_cache',
+ 'fused_moe_w8a8',
]
diff --git a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py
index 13b9841e9b..395b0c427e 100644
--- a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py
+++ b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py
@@ -2,210 +2,95 @@
import triton
from triton import language as tl
-from .triton_utils import get_kernel_meta, wrap_jit_func
-
def get_cuda_autotune_config():
return [
- # most used
- triton.Config(
- {
- 'BLOCK_SIZE_M': 128,
- 'BLOCK_SIZE_N': 64,
- 'BLOCK_SIZE_K': 32,
- 'GROUP_SIZE_M': 8
- },
- num_stages=4,
- num_warps=4),
- triton.Config(
- {
- 'BLOCK_SIZE_M': 64,
- 'BLOCK_SIZE_N': 64,
- 'BLOCK_SIZE_K': 64,
- 'GROUP_SIZE_M': 8
- },
- num_stages=4,
- num_warps=4),
- # # other
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 64,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=3,
- # num_warps=8),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 32,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 32,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=5,
- # num_warps=2),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 32,
- # 'BLOCK_SIZE_N': 64,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=5,
- # num_warps=2),
- # # Good config for fp8 inputs.
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=3,
- # num_warps=8),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 256,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=3,
- # num_warps=8),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 256,
- # 'BLOCK_SIZE_N': 64,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 64,
- # 'BLOCK_SIZE_K': 64,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 32,
- # 'BLOCK_SIZE_K': 64,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
+ triton.Config({
+ 'BLOCK_SIZE_N': 64,
+ 'GROUP_SIZE_M': 8,
+ },
+ num_stages=3,
+ num_warps=4),
]
@triton.jit
-def _get_unpacked_order(offs_n, elem_per_int: tl.constexpr):
- """get unpacked order."""
- origin_order = offs_n % elem_per_int
- unpacked_order = (origin_order & 1) * 4 + origin_order // 2
- return unpacked_order
+def _dequant_s4_to_f16x2(weight, shift: tl.constexpr, is_top: tl.constexpr):
+
+ immLut: tl.constexpr = (0xf0 & 0xcc) | 0xaa
+ BOTTOM_MASK: tl.constexpr = 0x000f000f
+ TOP_MASK: tl.constexpr = 0x00f000f0
+ I4s_TO_F16s_MAGIC_NUM: tl.constexpr = 0x64006400
+ FP16_TOP_MAGIC_NUM: tl.constexpr = 0x64006400
+ ONE_SIXTEENTH: tl.constexpr = 0x2c002c00
+ NEG_64: tl.constexpr = 0xd400d400
+
+ if shift:
+ weight = weight >> 8
+
+ if is_top:
+ return tl.inline_asm_elementwise("""{
+ .reg .b32 tmp;
+ lop3.b32 tmp, $2, $3, $4, $5;
+ fma.rn.f16x2 tmp, tmp, $6, $7;
+ mov.b32 {$0, $1}, tmp;
+ }""",
+ '=h,=h,r,n,n,n,r,r',
+ args=[
+ weight, TOP_MASK,
+ I4s_TO_F16s_MAGIC_NUM, immLut,
+ ONE_SIXTEENTH, NEG_64
+ ],
+ dtype=(tl.float16, tl.float16),
+ is_pure=True,
+ pack=1)
+ else:
+ return tl.inline_asm_elementwise("""{
+ .reg .b32 tmp;
+ lop3.b32 tmp, $2, $3, $4, $5;
+ sub.f16x2 tmp, tmp, $6;
+ mov.b32 {$0, $1}, tmp;
+ }""",
+ '=h,=h,r,n,n,n,r',
+ args=[
+ weight, BOTTOM_MASK,
+ I4s_TO_F16s_MAGIC_NUM, immLut,
+ FP16_TOP_MAGIC_NUM
+ ],
+ dtype=(tl.float16, tl.float16),
+ is_pure=True,
+ pack=1)
@triton.jit
-def _broadcast_pack(weight, width: tl.constexpr):
- """broadcast pack."""
- broadcast_tmp = tl.arange(0, width)
+def _unpack_weight(weight):
+ """unpack weight."""
+ # broadcast and shift
+ width: tl.constexpr = 8
BLOCK_SIZE_K: tl.constexpr = weight.shape[0]
BLOCK_SIZE_QN: tl.constexpr = weight.shape[1]
BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_QN * width
- weight = tl.broadcast(weight[:, :, None], broadcast_tmp[None, None, :])
- weight = tl.reshape(weight, (BLOCK_SIZE_K, BLOCK_SIZE_N))
- return weight
+ w0, w1 = _dequant_s4_to_f16x2(weight, False, False)
+ w2, w3 = _dequant_s4_to_f16x2(weight, False, True)
+ w4, w5 = _dequant_s4_to_f16x2(weight, True, False)
+ w6, w7 = _dequant_s4_to_f16x2(weight, True, True)
-@triton.jit
-def _unpack_weight(weight, order):
- """unpack weight."""
- weight = _broadcast_pack(weight, 8)
- weight = weight >> (order * 4)
- # cast to float16
- immLut = (0xf0 & 0xcc) | 0xaa
- BOTTOM_MASK = 0xf
- I4s_TO_F16s_MAGIC_NUM = 0x6400
- FP16_TOP_MAGIC_NUM = 0x6400
- weight = tl.inline_asm_elementwise(
- """lop3.b32 $1, $1, $2, $3, $4;
- sub.f16x2 $1, $1, $5;
- mov.b32 {$0, _}, $1;""",
- '=h, r, n, n, n, r', [
- weight, BOTTOM_MASK, I4s_TO_F16s_MAGIC_NUM, immLut,
- FP16_TOP_MAGIC_NUM
- ],
- dtype=tl.float16,
- is_pure=False,
- pack=1)
- return weight
+ w04 = tl.join(w0, w4)
+ w15 = tl.join(w1, w5)
+ w26 = tl.join(w2, w6)
+ w37 = tl.join(w3, w7)
+ w0246 = tl.join(w04, w26)
+ w1357 = tl.join(w15, w37)
+ weight = tl.join(w0246, w1357)
+
+ return weight.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
@triton.autotune(
configs=get_cuda_autotune_config(),
- key=['M_NEXT_P2', 'N', 'K'],
+ key=['N', 'K'],
)
-@wrap_jit_func
@triton.jit
def awq_linear_kernel(
a_ptr,
@@ -225,12 +110,9 @@ def awq_linear_kernel(
stride_zk: tl.constexpr,
stride_zn: tl.constexpr, #
stride_cm,
- stride_ck: tl.constexpr,
stride_cn: tl.constexpr,
# Meta-parameters
- M_NEXT_P2: tl.constexpr,
- Q_GROUP_SIZE: tl.constexpr,
- SPLIT_K_ITERS: tl.constexpr,
+ SPLIT_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, #
@@ -239,19 +121,13 @@ def awq_linear_kernel(
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
- ELEM_PER_INT = 8
- if Q_GROUP_SIZE > BLOCK_SIZE_K:
- GROUP_SIZE_K: tl.constexpr = BLOCK_SIZE_K
- else:
- GROUP_SIZE_K: tl.constexpr = Q_GROUP_SIZE
- K_PER_GROUP: tl.constexpr = Q_GROUP_SIZE // GROUP_SIZE_K
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
+ kid = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
- split_kid = tl.program_id(axis=1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -267,8 +143,7 @@ def awq_linear_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
BLOCK_SIZE_QN: tl.constexpr = BLOCK_SIZE_N // 8
offs_wn = pid_n * BLOCK_SIZE_QN + tl.arange(0, BLOCK_SIZE_QN)
- offs_k = tl.arange(0, GROUP_SIZE_K)
- unpacked_order = _get_unpacked_order(offs_bn, ELEM_PER_INT)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
offs_k[None, :] * stride_ak)
qw_ptrs = qw_ptr + (offs_k[:, None] * stride_wk +
@@ -276,49 +151,52 @@ def awq_linear_kernel(
s_ptrs = s_ptr + offs_bn * stride_sn
qz_ptrs = qz_ptr + offs_wn * stride_zn
- # split k
- NUM_K_BLOCKS = K // GROUP_SIZE_K
- K_PER_SPLIT = tl.cdiv(NUM_K_BLOCKS, SPLIT_K_ITERS)
- k_start = split_kid * K_PER_SPLIT
- k_last = min(k_start + K_PER_SPLIT, NUM_K_BLOCKS)
- a_ptrs += k_start * GROUP_SIZE_K * stride_ak
- qw_ptrs += k_start * GROUP_SIZE_K * stride_wk
- qg_id = k_start // K_PER_GROUP
-
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- s = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty)
- zs = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty)
+
+ k_start = kid
+ k_last = K // BLOCK_SIZE_K
# prefetch
- next_qw = tl.load(qw_ptrs)
- qw_ptrs += GROUP_SIZE_K * stride_wk
+ a_ptrs += k_start * BLOCK_SIZE_K * stride_ak
+ qw_ptrs += k_start * BLOCK_SIZE_K * stride_wk
+ s_ptrs += k_start * stride_sk
+ qz_ptrs += k_start * stride_zk
+ qw = tl.load(qw_ptrs)
+ qz = tl.load(qz_ptrs)[None, :]
+ s = tl.load(s_ptrs)[None, :]
+ qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk
+ s_ptrs += SPLIT_K * stride_sk
+ qz_ptrs += SPLIT_K * stride_zk
+
+ for k in tl.range(k_start, k_last, SPLIT_K, num_stages=3):
- for k in range(k_start, k_last):
+ # load a
a = tl.load(a_ptrs)
- qw = next_qw
- if k + 1 < k_last:
- next_qw = tl.load(qw_ptrs)
- w = _unpack_weight(qw, unpacked_order)
-
- if k == k_start or k % K_PER_GROUP == 0:
- s = tl.load(s_ptrs + qg_id * stride_sk)[None, :]
- qz = tl.load(qz_ptrs + qg_id * stride_zk)[None, :]
- qg_id += 1
- z = _unpack_weight(qz, unpacked_order)
- zs = -z * s
- b = w * s + zs
+
+ # unpack b
+ z = _unpack_weight(qz)
+ w = _unpack_weight(qw)
+ b = (w - z) * s
+
+ # load next q
+ mask = k + SPLIT_K < k_last
+ qz = tl.load(qz_ptrs, mask=mask)[None, :]
+ s = tl.load(s_ptrs, mask=mask)[None, :]
+ qw = tl.load(qw_ptrs, mask=mask)
# We accumulate along the K dimension.
- accumulator += tl.dot(a, b)
+ accumulator = tl.dot(a, b, acc=accumulator)
# Advance the ptrs to the next K block.
- a_ptrs += GROUP_SIZE_K * stride_ak
- qw_ptrs += GROUP_SIZE_K * stride_wk
+ a_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_ak
+ qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk
+ s_ptrs += SPLIT_K * stride_sk
+ qz_ptrs += SPLIT_K * stride_zk
c = accumulator.to(tl.float16)
@@ -329,11 +207,11 @@ def awq_linear_kernel(
c_ptrs = c_ptr + stride_cm * offs_cm[:,
None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
- if stride_ck > 0:
- c_ptrs += split_kid * stride_ck
- tl.store(c_ptrs, c, mask=c_mask)
+
+ if SPLIT_K > 1:
+ tl.atomic_add(c_ptrs, c, mask=c_mask, sem='relaxed', scope='gpu')
else:
- tl.atomic_add(c_ptrs, c, mask=c_mask)
+ tl.store(c_ptrs, c, mask=c_mask)
def awq_linear(x, qweight, scales, qzeros):
@@ -341,18 +219,24 @@ def awq_linear(x, qweight, scales, qzeros):
M = x.size(0)
K = qweight.size(0)
N = scales.size(1)
- SPLIT_K_ITERS = 4
group_size = K // scales.size(0)
+ SPLIT_K = max(1, K // 4096)
def grid(META):
"""grid."""
- return (triton.cdiv(M, META['BLOCK_SIZE_M']) *
- triton.cdiv(N, META['BLOCK_SIZE_N']), SPLIT_K_ITERS)
+ return (
+ triton.cdiv(M, META['BLOCK_SIZE_M']) *
+ triton.cdiv(N, META['BLOCK_SIZE_N']),
+ SPLIT_K,
+ )
- out = scales.new_empty(M, SPLIT_K_ITERS, N)
- M_NEXT_P2 = triton.next_power_of_2(M)
+ if SPLIT_K > 1:
+ out = scales.new_zeros(M, N)
+ else:
+ out = scales.new_empty(M, N)
- kernel_meta = get_kernel_meta(x)
+ BLOCK_SIZE_M = triton.next_power_of_2(M)
+ BLOCK_SIZE_M = max(16, min(128, BLOCK_SIZE_M))
awq_linear_kernel[grid](
# Pointers to matrices
x,
@@ -373,12 +257,11 @@ def grid(META):
stride_zk=qzeros.stride(0),
stride_zn=qzeros.stride(1), #
stride_cm=out.stride(0),
- stride_ck=out.stride(1),
- stride_cn=out.stride(2),
+ stride_cn=out.stride(1),
# Meta-parameters
- M_NEXT_P2=M_NEXT_P2,
- Q_GROUP_SIZE=group_size,
- SPLIT_K_ITERS=SPLIT_K_ITERS,
- **kernel_meta)
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
+ BLOCK_SIZE_K=group_size,
+ SPLIT_K=SPLIT_K,
+ )
- return out.sum(1)
+ return out
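
The rewritten kernel dequantizes each quantization group as `b = (w - z) * s` (the previous version carried `w * s` and `-z * s` separately) and unpacks the int4 values with inline PTX. A hedged PyTorch reference of the per-group dequant, assuming the 4-bit weights and zero points have already been unpacked to integer tensors in matching order:

```python
import torch


def dequant_awq_reference(w_int4: torch.Tensor, z_int4: torch.Tensor,
                          scales: torch.Tensor,
                          group_size: int) -> torch.Tensor:
    """Reference dequant b = (w - z) * s per K-group.

    w_int4: (K, N) unpacked 4-bit weights as integers.
    z_int4: (K // group_size, N) unpacked 4-bit zero points.
    scales: (K // group_size, N) fp16 scales.
    """
    z = z_int4.repeat_interleave(group_size, dim=0).to(torch.float16)
    s = scales.repeat_interleave(group_size, dim=0)
    return (w_int4.to(torch.float16) - z) * s


# x @ dequant_awq_reference(...) should match awq_linear(x, qweight, scales,
# qzeros) up to fp16 accumulation error, given a matching unpack order.
```
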
diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py
new file mode 100644
index 0000000000..4907d92ac5
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py
@@ -0,0 +1,344 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modify from: https://github.com/vllm-project/vllm
+import torch
+import triton
+import triton.language as tl
+
+from .activation import silu_and_mul
+from .blocked_gemm_fp8 import quant_fp8
+from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config({
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 64,
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config({
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 128,
+ },
+ num_stages=4,
+ num_warps=4),
+ ]
+
+
+@triton.autotune(
+ configs=get_cuda_autotune_config(),
+ key=['N', 'K', 'M_NP2'],
+ warmup=10,
+ rep=25,
+)
+@triton.jit
+def fused_moe_blocked_f8_kernel(
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ SortedIdx,
+ ExpStart,
+ ExpEnd,
+ Weights,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ group_ak: tl.constexpr,
+ group_bk: tl.constexpr,
+ group_bn: tl.constexpr,
+ stride_am: tl.constexpr,
+ stride_ak: tl.constexpr,
+ stride_asm,
+ stride_ask: tl.constexpr,
+ stride_be: tl.constexpr,
+ stride_bn: tl.constexpr,
+ stride_bk: tl.constexpr,
+ stride_bse: tl.constexpr,
+ stride_bsk: tl.constexpr,
+ stride_bsn: tl.constexpr,
+ stride_cm,
+ stride_cn: tl.constexpr,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ M_NP2: tl.constexpr,
+ ENABLE_WEIGHTS: tl.constexpr,
+ top_k: tl.constexpr,
+ expert_offset: tl.constexpr,
+ reindex_a: tl.constexpr,
+ reindex_c: tl.constexpr,
+):
+ """fused moe kernel."""
+ exp_id = tl.program_id(1)
+ pid = tl.program_id(0)
+
+ exp_start = tl.load(ExpStart + exp_id + expert_offset)
+ exp_end = tl.load(ExpEnd + exp_id + expert_offset)
+ M = exp_end - exp_start
+ if M <= 0:
+ return
+
+ num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+ if GROUP_SIZE_M == 1:
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+ else:
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
+ return
+
+ offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ mask_sid = offs_sid < exp_end
+ sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)
+
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ if reindex_a:
+ offs_am = sid // top_k
+ else:
+ offs_am = offs_sid
+ a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ as_ptrs = A_scale + offs_am
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
+ BLOCK_SIZE_N)
+
+    # deepseek has 160 experts; exp_id * stride_be may overflow int32
+ exp_id = exp_id.to(tl.int64)
+ exp_off = stride_be * exp_id
+ b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk +
+ offs_bn[None, :] * stride_bn)
+
+ offs_bsn = pid_n * BLOCK_SIZE_N // group_bn
+ as_ptrs = A_scale + offs_am * stride_asm
+ bs_ptrs = B_scale + stride_bse * exp_id + offs_bsn * stride_bsn
+
+ acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)
+ acc_ratio = 1 / acc_scale
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # load scales
+ k_start = (k + 1) * BLOCK_SIZE_K
+ offs_ksa = k_start // group_ak
+ offs_ksb = k_start // group_bk
+ a_scale = tl.load(as_ptrs + offs_ksa * stride_ask,
+ mask=k_start < K,
+ other=1.0)
+ b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk,
+ mask=k_start < K,
+ other=1.0)
+
+ # load ab
+ a = tl.load(a_ptrs,
+ mask=mask_sid[:, None] &
+ (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+ other=0.0)
+ b = tl.load(b_ptrs,
+ mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+
+ # mma
+ accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])
+
+ # update scales and ratio
+ new_acc_scale = a_scale * b_scale
+ acc_ratio = acc_scale / new_acc_scale
+ acc_scale = new_acc_scale
+
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+
+ c = accumulator * (acc_ratio * acc_scale)[:, None]
+
+ if ENABLE_WEIGHTS:
+ weight = tl.load(Weights + sid, mask=mask_sid)
+ c = c * weight[:, None].to(c.dtype)
+
+ c = c.to(C.dtype.element_ty)
+
+ if reindex_c:
+ offs_cm = sid
+ else:
+ offs_cm = offs_sid
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]
+ tl.store(c_ptrs, c, mask=mask_sid[:, None])
+
+
+def fused_moe_blocked_fp8_kernel_launcher(
+ A: torch.Tensor,
+ A_scale: torch.Tensor,
+ B: torch.Tensor,
+ B_scale: torch.Tensor,
+ C: torch.Tensor,
+ sorted_idx: torch.Tensor,
+ exp_start: torch.Tensor,
+ exp_end: torch.Tensor,
+ weights: torch.Tensor,
+ enable_weights: bool = False,
+ top_k: int = 1,
+ num_tokens: int = None,
+ expert_offset: int = 0,
+ reindex_a: bool = True,
+ reindex_c: bool = True,
+):
+ """fused moe kernel launcher."""
+
+ if num_tokens is None:
+ num_tokens = A.size(0)
+ M_NP2 = triton.next_power_of_2(num_tokens)
+ M_NP2 = max(64, M_NP2)
+ E, N, K = B.shape
+
+ assert A.dim() == 2
+ assert A_scale.dim() == 2
+ assert B.dim() == 3
+ assert B_scale.dim() == 3
+
+ assert K % A_scale.size(1) == 0
+ assert K % B_scale.size(2) == 0
+ assert N % B_scale.size(1) == 0
+
+ group_ak = K // A_scale.size(1)
+ group_bk = K // B_scale.size(2)
+ group_bn = N // B_scale.size(1)
+
+ def _grid_fn(META):
+ grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) *
+ triton.cdiv(N, META['BLOCK_SIZE_N']), E)
+ return grid
+
+ A = A.flatten(0, -2)
+ C = C.flatten(0, -2)
+
+ BLOCK_SIZE_K = group_bk
+ GROUP_SIZE_M = 8
+ grid = _grid_fn
+ fused_moe_blocked_f8_kernel[grid](
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ sorted_idx,
+ exp_start,
+ exp_end,
+ weights,
+ N=N,
+ K=K,
+ group_ak=group_ak,
+ group_bk=group_bk,
+ group_bn=group_bn,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_asm=A_scale.stride(0),
+ stride_ask=A_scale.stride(1),
+ stride_be=B.stride(0),
+ stride_bn=B.stride(1),
+ stride_bk=B.stride(2),
+ stride_bse=B_scale.stride(0),
+ stride_bsn=B_scale.stride(1),
+ stride_bsk=B_scale.stride(2),
+ stride_cm=C.stride(0),
+ stride_cn=C.stride(1),
+ ENABLE_WEIGHTS=enable_weights,
+ top_k=top_k,
+ expert_offset=expert_offset,
+ reindex_a=reindex_a,
+ reindex_c=reindex_c,
+ M_NP2=M_NP2,
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ GROUP_SIZE_M=GROUP_SIZE_M,
+ )
+
+
+def fused_moe_blocked_fp8(input: torch.Tensor,
+ input_scale: torch.Tensor,
+ w1: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2: torch.Tensor,
+ w2_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ topk: int,
+ out_dtype: torch.dtype = torch.float16,
+ expert_offset: int = 0,
+ num_experts: int = None,
+ renormalize: bool = False) -> torch.Tensor:
+ """fused moe."""
+ device = input.device
+ M = input.size(0)
+ E, N, _ = w1.shape
+ if num_experts is None:
+ num_experts = E
+ full_exp = num_experts == E
+ group_size = input.size(-1) // input_scale.size(-1)
+
+ topk_weights = _renormalize(topk_weights, renormalize)
+ sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)
+
+ intermediate_cache1 = _make_intermediate((M, topk, N),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # gate and up
+ fused_moe_blocked_fp8_kernel_launcher(
+ input,
+ input_scale,
+ w1,
+ w1_scale,
+ intermediate_cache1,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=False,
+ top_k=topk,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=True,
+ reindex_c=False,
+ )
+
+ # activate
+ intermediate_cache1 = intermediate_cache1.flatten(0, -2)
+ gate_cache = silu_and_mul(intermediate_cache1)
+ del intermediate_cache1
+ gate_cache, gate_scale = quant_fp8(gate_cache,
+ group_size,
+ dtype=input.dtype)
+
+ intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # down
+ fused_moe_blocked_fp8_kernel_launcher(
+ gate_cache,
+ gate_scale,
+ w2,
+ w2_scale,
+ intermediate_cache2,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=True,
+ top_k=1,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=False,
+ reindex_c=True,
+ )
+
+ ret = intermediate_cache2.sum(dim=1)
+ return ret
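
Both FP8 kernels keep the accumulator in the units of the most recently loaded block scale: each `tl.dot` folds in `acc * acc_ratio`, where `acc_ratio` is old_scale / new_scale, and the tile is multiplied back by the final scale after the loop. A simplified PyTorch reference of that running-rescale accumulation, using one scalar scale per K-block instead of the kernels' per-group scales:

```python
import torch


def blocked_scaled_matmul_reference(a_blocks, a_scales, b_blocks, b_scales):
    """Running-rescale accumulation over per-block scaled operands.

    a_blocks[k] @ b_blocks[k] is the k-th partial product of a tile; each
    block has a scalar scale (a simplification of the kernels' per-group
    scales). The accumulator stays at the current combined scale and is
    rescaled by old/new whenever that scale changes.
    """
    acc = torch.zeros(a_blocks[0].size(0), b_blocks[0].size(1))
    acc_scale = a_scales[0] * b_scales[0]
    acc_ratio = 1.0 / acc_scale
    for k in range(len(a_blocks)):
        acc = a_blocks[k] @ b_blocks[k] + acc * acc_ratio
        next_scale = (a_scales[k + 1] * b_scales[k + 1]
                      if k + 1 < len(a_blocks) else 1.0)
        acc_ratio = acc_scale / next_scale
        acc_scale = next_scale
    return acc * (acc_ratio * acc_scale)


# sanity check against a plain fp32 matmul of the dequantized blocks
a = [torch.randn(4, 16) for _ in range(3)]
b = [torch.randn(16, 8) for _ in range(3)]
sa, sb = [0.5, 2.0, 1.5], [1.0, 0.25, 3.0]
expected = sum((ai * si) @ (bi * sj) for ai, si, bi, sj in zip(a, sa, b, sb))
assert torch.allclose(blocked_scaled_matmul_reference(a, sa, b, sb),
                      expected, atol=1e-4)
```
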
diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
new file mode 100644
index 0000000000..9f992bcfef
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
@@ -0,0 +1,237 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+
+
+@triton.jit
+def _quant_fp8_kernel(
+ a_ptr,
+ out_ptr,
+ scale_ptr,
+ fp8_min: tl.constexpr,
+ fp8_max: tl.constexpr,
+ stride_am,
+ stride_ak: tl.constexpr,
+ stride_om,
+ stride_ok: tl.constexpr,
+ stride_sm,
+ stride_sg: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ """quant fp8 kernel."""
+ group_id = tl.program_id(0)
+ m_id = tl.program_id(1)
+
+ g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+
+ a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak
+ o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok
+ s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg
+
+ rfp8_max = 1 / fp8_max
+
+ a = tl.load(a_ptrs).to(tl.float32)
+ scale = tl.max(tl.abs(a)) * rfp8_max
+ out = a / scale
+
+ out = tl.clamp(out, fp8_min, fp8_max)
+ out = out.to(out_ptr.dtype.element_ty)
+
+ tl.store(o_ptrs, out)
+ tl.store(s_ptr, scale)
+
+
+def quant_fp8(A: Tensor,
+ group_size: int,
+ dtype: torch.dtype = torch.float8_e4m3fn):
+ """quant online."""
+ assert A.dim() == 2
+ M, K = A.shape
+ assert K % group_size == 0
+ num_groups = K // group_size
+
+ finfo = torch.finfo(dtype)
+ fmin = finfo.min
+ fmax = finfo.max
+
+ out = torch.empty_like(A, dtype=dtype)
+ scales = A.new_empty(M, num_groups, dtype=torch.float32)
+ grid = (num_groups, M)
+ num_warps = 4
+ num_stages = 1
+ _quant_fp8_kernel[grid](
+ A,
+ out,
+ scales,
+ fp8_min=fmin,
+ fp8_max=fmax,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_om=out.stride(0),
+ stride_ok=out.stride(1),
+ stride_sm=scales.stride(0),
+ stride_sg=scales.stride(1),
+ GROUP_SIZE=group_size,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+
+ return out, scales
+
+
+@triton.autotune(configs=[
+ triton.Config({
+ 'BLOCK_M': 64,
+ 'BLOCK_N': 128,
+ }, num_stages=3, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 64,
+ }, num_stages=3, num_warps=4)
+],
+ key=['N', 'K'],
+ warmup=5,
+ rep=10)
+@triton.jit
+def _gemm_fp8_kernel(
+ A,
+ a_scale_ptr,
+ B,
+ b_scale_ptr,
+ C,
+ M,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ group_ak: tl.constexpr,
+ group_bk: tl.constexpr,
+ group_bn: tl.constexpr,
+ stride_am,
+ stride_ak: tl.constexpr,
+ stride_asm,
+ stride_ask: tl.constexpr,
+ stride_bk: tl.constexpr,
+ stride_bn: tl.constexpr,
+ stride_bsk: tl.constexpr,
+ stride_bsn: tl.constexpr,
+ stride_cm,
+ stride_cn: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ """gemm fp8 kernel."""
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_pid_in_group = GROUP_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
+ offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
+ offs_k = tl.arange(0, BLOCK_K)
+ a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+ offs_bsn = pid_n * BLOCK_N // group_bn
+ as_ptrs = a_scale_ptr + offs_am * stride_asm
+ bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn
+
+ acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)
+ acc_ratio = 1 / acc_scale
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ # load scales
+ k_start = (k + 1) * BLOCK_K
+ offs_ksa = k_start // group_ak
+ offs_ksb = k_start // group_bk
+ a_scale = tl.load(as_ptrs + offs_ksa * stride_ask,
+ mask=k_start < K,
+ other=1.0)
+ b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk,
+ mask=k_start < K,
+ other=1.0)
+
+ # load ab
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
+
+ # mma
+ accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])
+
+ # update scales and ratio
+ new_acc_scale = a_scale * b_scale
+ acc_ratio = acc_scale / new_acc_scale
+ acc_scale = new_acc_scale
+
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+ c = accumulator * (acc_ratio * acc_scale)[:, None]
+
+ offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+
+def blocked_gemm_fp8(A: Tensor,
+ A_scale: Tensor,
+ B: Tensor,
+ B_scale: torch.Tensor,
+ out_dtype: torch.dtype = torch.float16):
+ """gemm fp8."""
+
+ def grid(META):
+ return (triton.cdiv(M, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']), )
+
+ assert A.dim() == 2
+ assert A_scale.dim() == 2
+ assert B.dim() == 2
+ assert B_scale.dim() == 2
+
+ M, K = A.shape
+ _, N = B.shape
+
+ group_ak = triton.cdiv(K, A_scale.size(1))
+ group_bk = triton.cdiv(K, B_scale.size(0))
+ group_bn = triton.cdiv(N, B_scale.size(1))
+
+ C = A.new_empty(M, N, dtype=out_dtype)
+
+ BLOCK_K = max(group_ak, group_bk)
+
+ _gemm_fp8_kernel[grid](
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ M=M,
+ N=N,
+ K=K,
+ group_ak=group_ak,
+ group_bk=group_bk,
+ group_bn=group_bn,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_asm=A_scale.stride(0),
+ stride_ask=A_scale.stride(1),
+ stride_bk=B.stride(0),
+ stride_bn=B.stride(1),
+ stride_bsk=B_scale.stride(0),
+ stride_bsn=B_scale.stride(1),
+ stride_cm=C.stride(0),
+ stride_cn=C.stride(1),
+ BLOCK_K=BLOCK_K,
+ GROUP_M=8,
+ )
+
+ return C
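
A hedged usage sketch of `quant_fp8` (assumes a CUDA device and a torch build
with float8 support; not part of the patch): quantize per 128-wide group, then
dequantize with the returned per-group scales to check the round trip.

import torch
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8

x = torch.randn(16, 512, device='cuda', dtype=torch.float16)
x_q, x_s = quant_fp8(x, group_size=128)   # x_q: (16, 512) fp8, x_s: (16, 4)
# dequantize: broadcast each group scale back over its 128 columns
x_dq = x_q.to(torch.float32).view(16, 4, 128) * x_s[..., None]
print((x_dq.view(16, 512) - x.float()).abs().max())  # small, fp8-level error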
diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
index e2b2091b84..5f59ac4651 100644
--- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
@@ -31,7 +31,7 @@ def _flatten_kv_cache(
stride_vos: tl.constexpr,
stride_vod: tl.constexpr,
stride_boff,
- OUT_SIZE: tl.constexpr,
+ OUT_SIZE,
HEAD_DIM_K: tl.constexpr,
HEAD_DIM_V: tl.constexpr,
BLOCK_BS: tl.constexpr,
@@ -125,7 +125,7 @@ def _flatten_kv_cache_quant(
stride_vod: tl.constexpr,
stride_boff,
quant_policy: tl.constexpr,
- OUT_SIZE: tl.constexpr,
+ OUT_SIZE,
HEAD_DIM_K: tl.constexpr,
HEAD_DIM_V: tl.constexpr,
BLOCK_BS: tl.constexpr,
diff --git a/lmdeploy/pytorch/kernels/cuda/fused_lora.py b/lmdeploy/pytorch/kernels/cuda/fused_lora.py
index d7fbb34588..3dc7e3a10b 100644
--- a/lmdeploy/pytorch/kernels/cuda/fused_lora.py
+++ b/lmdeploy/pytorch/kernels/cuda/fused_lora.py
@@ -9,8 +9,8 @@ def get_autotune_config():
return [
triton.Config(
{
- 'BLOCK_SIZE_M': 64,
- 'BLOCK_SIZE_N': 256,
+ 'BLOCK_SIZE_M': 32,
+ 'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 128
},
num_stages=4,
@@ -26,9 +26,26 @@ def get_autotune_config():
]
+@triton.jit
+def _atomic_store(ptrs, val, mask):
+ """atomic store values."""
+ dtype = ptrs.dtype.element_ty
+ if (dtype == torch.float16) | (dtype == torch.float32):
+ tl.atomic_add(ptrs, val, mask=mask, sem='relaxed')
+ else:
+ # bfloat16 does not support atomic add
+ origin = tl.load(ptrs, mask=mask)
+ val = val.to(origin.dtype)
+ val += origin
+ tl.store(ptrs, val, mask=mask)
+
+
@triton.autotune(
configs=get_autotune_config(),
key=['N', 'K'],
+ restore_value=['c_ptr'],
+ warmup=5,
+ rep=20,
)
@triton.jit
def _fused_lora_kernel(
@@ -44,18 +61,19 @@ def _fused_lora_kernel(
adapter_ids_ptr,
N: tl.constexpr,
K: tl.constexpr,
- stride_am: tl.constexpr,
+ stride_am,
stride_ak: tl.constexpr,
stride_lar: tl.constexpr,
stride_lak: tl.constexpr,
stride_lbr: tl.constexpr,
stride_lbn: tl.constexpr,
- stride_cm: tl.constexpr,
+ stride_cm,
stride_cn: tl.constexpr,
BLOCK_SIZE_R: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
+ CUM: tl.constexpr,
):
"""fused lora kernel."""
pid = tl.program_id(axis=0)
@@ -70,87 +88,91 @@ def _fused_lora_kernel(
rank_start = tl.load(rank_start_ptr + adapter_id)
rank = tl.load(ranks_ptr + adapter_id)
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
- GROUP_SIZE_M: tl.constexpr = 1
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_SIZE_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
+ pid_m = pid
if pid_m * BLOCK_SIZE_M >= M:
return
offs_m = (seq_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
- offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
mask_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) < M
- if rank == 0:
- offs_cm = offs_m
- offs_cn = offs_n
- c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[
- None, :]
- c_mask = mask_cm[:, None] & (offs_cn[None, :] < N)
- tl.store(c_ptrs, 0, mask=c_mask)
- return
-
- offs_am = (seq_start +
- (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M)
- offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank
- offs_k = tl.arange(0, BLOCK_SIZE_K)
- a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
- offs_k[None, :] * stride_ak)
- la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak +
- offs_r[None, :] * stride_lar)
-
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
- # Load the next block of A and B
- # If it is out of bounds, set it to 0.
- a = tl.load(a_ptrs,
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
- other=0.0)
- la = tl.load(la_ptrs,
- mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
- other=0.0)
- # We accumulate along the K dimension.
- accumulator += tl.dot(a, la)
- # Advance the ptrs to the next K block.
- a_ptrs += BLOCK_SIZE_K * stride_ak
- la_ptrs += BLOCK_SIZE_K * stride_lak
- ar = accumulator.to(lora_b_ptr.dtype.element_ty)
-
- offs_lbn = offs_n % N
- lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr +
- offs_lbn * stride_lbn)
- lb = tl.load(lb_ptrs, mask=tl.arange(0, BLOCK_SIZE_R)[:, None] < rank)
-
- c = tl.dot(ar, lb)
-
- scaling = tl.load(scaling_ptr + adapter_id)
- c *= scaling
-
- c = c.to(c_ptr.dtype.element_ty)
offs_cm = offs_m
- offs_cn = offs_n
- c_ptrs = c_ptr + stride_cm * offs_cm[:,
- None] + stride_cn * offs_cn[None, :]
- c_mask = mask_cm[:, None] & (offs_cn[None, :] < N)
- tl.store(c_ptrs, c, mask=c_mask)
-
-
-def fused_lora(input: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
- scaling: torch.LongTensor, rank_start: torch.LongTensor,
- ranks: torch.LongTensor, seq_start: torch.LongTensor,
- seq_lens: torch.LongTensor, adapter_ids: torch.LongTensor,
- max_rank: int, max_seqlen: int):
+ c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_n[None, :]
+
+ if rank == 0:
+ if not CUM:
+ for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
+ mask_cn = (offs_n < N - n * BLOCK_SIZE_N)
+ c_mask = mask_cm[:, None] * mask_cn[None, :]
+ tl.store(c_ptrs, 0.0, mask=c_mask)
+ c_ptrs += stride_cn * BLOCK_SIZE_N
+ else:
+
+ offs_am = (seq_start +
+ (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M)
+ offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
+ offs_k[None, :] * stride_ak)
+ la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak +
+ offs_r[None, :] * stride_lar)
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # Load the next block of A and B
+ # If it is out of bounds, set it to 0.
+ a = tl.load(a_ptrs,
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+ la = tl.load(la_ptrs,
+ mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(a, la, acc=accumulator)
+ # Advance the ptrs to the next K block.
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ la_ptrs += BLOCK_SIZE_K * stride_lak
+ ar = accumulator.to(lora_b_ptr.dtype.element_ty)
+
+ scaling = tl.load(scaling_ptr + adapter_id).to(ar.dtype)
+ ar *= scaling
+ ar = tl.where(
+ tl.arange(0, BLOCK_SIZE_R)[None, :] < rank, ar, tl.zeros_like(ar))
+ lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr +
+ offs_n[None, :] * stride_lbn)
+
+ for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
+ lb = tl.load(lb_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N)
+ c = tl.dot(ar, lb)
+
+ mask_cn = (offs_n < N - n * BLOCK_SIZE_N)
+ c_mask = mask_cm[:, None] * mask_cn[None, :]
+ if CUM:
+ _atomic_store(c_ptrs, c, mask=c_mask)
+ else:
+ tl.store(c_ptrs, c, mask=c_mask)
+ c_ptrs += stride_cn * BLOCK_SIZE_N
+ lb_ptrs += stride_lbn * BLOCK_SIZE_N
+
+
+def fused_lora(input: torch.Tensor,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ scaling: torch.LongTensor,
+ rank_start: torch.LongTensor,
+ ranks: torch.LongTensor,
+ seq_start: torch.LongTensor,
+ seq_lens: torch.LongTensor,
+ adapter_ids: torch.LongTensor,
+ max_rank: int,
+ max_seqlen: int,
+ output: torch.Tensor = None,
+ cum: bool = False):
"""fused lora."""
def grid(META):
- ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M']) *
- triton.cdiv(N, META['BLOCK_SIZE_N'])), batch_size)
+ ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M'])), batch_size)
return ret
assert input.dim() == 2
@@ -158,7 +180,12 @@ def grid(META):
M, K = input.shape
N = lora_b.size(1)
- output = input.new_empty((M, N))
+ if output is None:
+ output = input.new_empty((M, N))
+ cum = False
+ else:
+ assert output.size(0) == M
+ assert output.size(1) == N
BLOCK_SIZE_R = max(16, max_rank)
_fused_lora_kernel[grid](
@@ -183,6 +210,7 @@ def grid(META):
stride_cm=output.stride(0),
stride_cn=output.stride(1),
BLOCK_SIZE_R=BLOCK_SIZE_R,
+ CUM=cum,
)
return output
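
A hedged pure-PyTorch reference for the new `output=`/`cum=` path (illustrative
only; the packed `lora_a`/`lora_b` layout and the per-adapter bookkeeping of
`rank_start`/`ranks` are omitted here):

import torch

def ref_lora_accumulate(base_out, x, lora_a, lora_b, scaling):
    # x: (M, K), lora_a: (K, r), lora_b: (r, N), base_out: (M, N)
    # fused_lora(..., output=base_out, cum=True) is intended to add the scaled
    # LoRA delta into base_out instead of writing a fresh (M, N) buffer.
    return base_out + scaling * (x @ lora_a) @ lora_b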
diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py
index 9f9771368e..9d73208c53 100644
--- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py
+++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py
@@ -91,8 +91,6 @@ def fused_moe_kernel(
if GROUP_SIZE_M == 1:
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
- # pid_m = pid // num_pid_n
- # pid_n = pid % num_pid_n
else:
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
@@ -133,7 +131,7 @@ def fused_moe_kernel(
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
- accumulator += tl.dot(a, b)
+ accumulator = tl.dot(a, b, acc=accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -271,6 +269,33 @@ def get_start_end(topk_idx: torch.Tensor, sorted_idx: torch.Tensor,
return exp_start, exp_end
+def _get_sorted_idx(topk_ids: torch.Tensor, num_experts: int):
+ """get sorted idx."""
+ flatten_topk_ids = topk_ids.flatten()
+ sorted_idx = flatten_topk_ids.argsort()
+
+ exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx,
+ num_experts)
+ return sorted_idx, exp_start, exp_end
+
+
+def _renormalize(topk_weights: torch.Tensor, renormalize: bool):
+ if renormalize:
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+ if not topk_weights.is_contiguous():
+ topk_weights = topk_weights.contiguous()
+ return topk_weights
+
+
+def _make_intermediate(shape: tuple, dtype: torch.dtype, device: torch.device,
+ zeros: bool):
+ """make intermediate."""
+ if zeros:
+ return torch.zeros(shape, dtype=dtype, device=device)
+ else:
+ return torch.empty(shape, dtype=dtype, device=device)
+
+
def fused_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -283,31 +308,17 @@ def fused_moe(hidden_states: torch.Tensor,
"""fused moe."""
M = hidden_states.size(0)
E, N, _ = w1.shape
- full_exp = False
if num_experts is None:
num_experts = E
- elif num_experts == E:
- full_exp = True
-
- def __get_sorted_idx(topk_ids: torch.Tensor):
- flatten_topk_ids = topk_ids.flatten()
- sorted_idx = flatten_topk_ids.argsort()
-
- exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx,
- num_experts)
- return sorted_idx, exp_start, exp_end
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- if not topk_weights.is_contiguous():
- topk_weights = topk_weights.contiguous()
+ full_exp = num_experts == E
- sorted_idx, exp_start, exp_end = __get_sorted_idx(topk_ids)
+ topk_weights = _renormalize(topk_weights, renormalize)
+ sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)
- if full_exp:
- intermediate_cache1 = hidden_states.new_empty((M, topk, N))
- else:
- intermediate_cache1 = hidden_states.new_zeros((M, topk, N))
+ intermediate_cache1 = _make_intermediate((M, topk, N),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ zeros=not full_exp)
# gate and up
fused_moe_kernel_launcher(
hidden_states,
@@ -331,10 +342,10 @@ def __get_sorted_idx(topk_ids: torch.Tensor):
gate_cache = silu_and_mul(intermediate_cache1)
gate_cache = gate_cache.unflatten(0, unflat_size)
- if full_exp:
- intermediate_cache2 = hidden_states.new_empty((M, topk, w2.shape[1]))
- else:
- intermediate_cache2 = hidden_states.new_zeros((M, topk, w2.shape[1]))
+ intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ zeros=not full_exp)
# down
fused_moe_kernel_launcher(
gate_cache,
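
A hedged toy example of what the extracted `_get_sorted_idx` helper produces
(plain PyTorch on CPU; the real helper computes the same per-expert offsets
with a Triton scan via `get_start_end`):

import torch

topk_ids = torch.tensor([[0, 2], [2, 1], [0, 0]])  # 3 tokens, top-2, 4 experts
flat = topk_ids.flatten()
sorted_idx = flat.argsort()              # token*topk slots grouped by expert
counts = torch.bincount(flat, minlength=4)
exp_end = counts.cumsum(0)
exp_start = exp_end - counts
# expert e owns sorted_idx[exp_start[e]:exp_end[e]] in the fused kernels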
diff --git a/lmdeploy/pytorch/kernels/cuda/rms_norm.py b/lmdeploy/pytorch/kernels/cuda/rms_norm.py
index bc994012fc..045b55e1ba 100644
--- a/lmdeploy/pytorch/kernels/cuda/rms_norm.py
+++ b/lmdeploy/pytorch/kernels/cuda/rms_norm.py
@@ -4,8 +4,6 @@
import triton.language as tl
from torch import Tensor
-from .triton_utils import get_kernel_meta, wrap_jit_func
-
@triton.jit
def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
@@ -18,15 +16,6 @@ def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
return out
-@wrap_jit_func(type_hint=dict(
- input=Tensor,
- weight=Tensor,
- output=Tensor,
- input_row_stride=int,
- eps=float,
- N_COLS=torch.int32,
- BLOCK_N=torch.int32,
-))
@triton.jit
def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
eps: tl.constexpr, N_COLS: tl.constexpr,
@@ -45,18 +34,6 @@ def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
-@wrap_jit_func(type_hint=dict(
- input=Tensor,
- weight=Tensor,
- residual=Tensor,
- output=Tensor,
- out_residual=Tensor,
- input_row_stride=int,
- residual_row_stride=int,
- eps=float,
- N_COLS=torch.int32,
- BLOCK_N=torch.int32,
-))
@triton.jit
def add_rms_norm_kernel(input, weight, residual, output, out_residual,
input_row_stride: tl.constexpr,
@@ -95,6 +72,7 @@ def rms_norm(hidden_states: Tensor,
hidden_states = hidden_states.contiguous()
feat_size = weight.shape[0]
+ assert hidden_states.size(-1) == feat_size
seq_len = hidden_states.numel() // hidden_states.size(-1)
input_stride = hidden_states.stride(-2)
@@ -103,39 +81,40 @@ def rms_norm(hidden_states: Tensor,
if out is None:
out = torch.empty_like(hidden_states)
- kernel_meta = get_kernel_meta(hidden_states)
grid = (seq_len, )
if residual is None:
- rms_norm_kernel[grid](hidden_states,
- weight,
- out,
- input_row_stride=input_stride,
- eps=eps,
- N_COLS=feat_size,
- BLOCK_N=BLOCK_N,
- num_warps=4,
- num_stages=2,
- **kernel_meta)
+ rms_norm_kernel[grid](
+ hidden_states,
+ weight,
+ out,
+ input_row_stride=input_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ num_warps=4,
+ num_stages=2,
+ )
return out
else:
if out_residual is None:
out_residual = torch.empty_like(hidden_states)
res_stride = residual.stride(-2)
- add_rms_norm_kernel[grid](hidden_states,
- weight,
- residual,
- out,
- out_residual,
- input_row_stride=input_stride,
- residual_row_stride=res_stride,
- eps=eps,
- N_COLS=feat_size,
- BLOCK_N=BLOCK_N,
- num_warps=4,
- num_stages=2,
- **kernel_meta)
+ add_rms_norm_kernel[grid](
+ hidden_states,
+ weight,
+ residual,
+ out,
+ out_residual,
+ input_row_stride=input_stride,
+ residual_row_stride=res_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ num_warps=4,
+ num_stages=2,
+ )
return out, out_residual
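
With the new assert, the last dimension of the input must match the weight.
A hedged usage sketch (assumes a CUDA device; not part of the patch):

import torch
from lmdeploy.pytorch.kernels.cuda.rms_norm import rms_norm

x = torch.randn(4, 1024, device='cuda', dtype=torch.float16)
w = torch.ones(1024, device='cuda', dtype=torch.float16)
out = rms_norm(x, w, eps=1e-5)
out, new_res = rms_norm(x, w, eps=1e-5, residual=torch.zeros_like(x))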
diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py
new file mode 100644
index 0000000000..72d9d802a4
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py
@@ -0,0 +1,312 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modify from: https://github.com/vllm-project/vllm
+import torch
+import triton
+import triton.language as tl
+
+from .activation import silu_and_mul
+from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize
+from .triton_utils import get_kernel_meta
+from .w8a8_triton_kernels import per_token_quant_int8
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 1,
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 256,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 1,
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 64,
+ 'GROUP_SIZE_M': 1,
+ },
+ num_stages=4,
+ num_warps=4),
+ ]
+
+
+@triton.autotune(
+ configs=get_cuda_autotune_config(),
+ key=['N', 'K', 'M_NP2'],
+ warmup=10,
+ rep=25,
+)
+@triton.jit
+def fused_moe_w8a8_kernel(
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ SortedIdx,
+ ExpStart,
+ ExpEnd,
+ Weights,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ stride_am: tl.constexpr,
+ stride_ak: tl.constexpr,
+ stride_be: tl.constexpr,
+ stride_bn: tl.constexpr,
+ stride_bk: tl.constexpr,
+ stride_bse: tl.constexpr,
+ stride_cm: tl.constexpr,
+ stride_cn: tl.constexpr,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ M_NP2: tl.constexpr,
+ ENABLE_WEIGHTS: tl.constexpr,
+ top_k: tl.constexpr,
+ expert_offset: tl.constexpr,
+ reindex_a: tl.constexpr,
+ reindex_c: tl.constexpr,
+):
+ """fused moe kernel."""
+ exp_id = tl.program_id(1)
+ pid = tl.program_id(0)
+
+ exp_start = tl.load(ExpStart + exp_id + expert_offset)
+ exp_end = tl.load(ExpEnd + exp_id + expert_offset)
+ M = exp_end - exp_start
+ if M <= 0:
+ return
+
+ num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+ if GROUP_SIZE_M == 1:
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+ else:
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
+ return
+
+ offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ mask_sid = offs_sid < exp_end
+ sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)
+
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ if reindex_a:
+ offs_am = sid // top_k
+ else:
+ offs_am = offs_sid
+ a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ as_ptrs = A_scale + offs_am
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
+ BLOCK_SIZE_N)
+
+    # deepseek has 160 experts; exp_id * stride_be may overflow int32
+ exp_id = exp_id.to(tl.int64)
+ exp_off = stride_be * exp_id
+ b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk +
+ offs_bn[None, :] * stride_bn)
+ bs_ptrs = B_scale + exp_id * stride_bse + offs_bn
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ a = tl.load(a_ptrs,
+ mask=mask_sid[:, None] &
+ (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+ other=0.0)
+ b = tl.load(b_ptrs,
+ mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+ accumulator = tl.dot(a, b, acc=accumulator)
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+
+ ascale = tl.load(as_ptrs, mask=mask_sid)
+ bscale = tl.load(bs_ptrs)
+ c = accumulator.to(ascale.dtype)
+ c = c * ascale[:, None] * bscale[None, :]
+
+ if ENABLE_WEIGHTS:
+ weight = tl.load(Weights + sid, mask=mask_sid)
+ c = c * weight[:, None].to(c.dtype)
+
+ c = c.to(C.dtype.element_ty)
+
+ if reindex_c:
+ offs_cm = sid
+ else:
+ offs_cm = offs_sid
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]
+ tl.store(c_ptrs, c, mask=mask_sid[:, None])
+
+
+def fused_moe_w8a8_kernel_launcher(
+ A: torch.Tensor,
+ A_scale: torch.Tensor,
+ B: torch.Tensor,
+ B_scale: torch.Tensor,
+ C: torch.Tensor,
+ sorted_idx: torch.Tensor,
+ exp_start: torch.Tensor,
+ exp_end: torch.Tensor,
+ weights: torch.Tensor,
+ enable_weights: bool = False,
+ top_k: int = 1,
+ num_tokens: int = None,
+ expert_offset: int = 0,
+ reindex_a: bool = True,
+ reindex_c: bool = True,
+):
+ """fused moe kernel launcher."""
+
+ if num_tokens is None:
+ num_tokens = A.size(0)
+ M_NP2 = triton.next_power_of_2(num_tokens)
+ M_NP2 = max(64, M_NP2)
+ E, N, K = B.shape
+
+ assert A_scale.is_contiguous()
+ assert B_scale.is_contiguous()
+
+ def _grid_fn(META):
+ grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) *
+ triton.cdiv(N, META['BLOCK_SIZE_N']), E)
+ return grid
+
+ A = A.flatten(0, -2)
+ C = C.flatten(0, -2)
+
+ grid = _grid_fn
+ kernel_meta = get_kernel_meta(A)
+ fused_moe_w8a8_kernel[grid](
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ sorted_idx,
+ exp_start,
+ exp_end,
+ weights,
+ N=N,
+ K=K,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_be=B.stride(0),
+ stride_bn=B.stride(1),
+ stride_bk=B.stride(2),
+ stride_bse=B_scale.stride(0),
+ stride_cm=C.stride(0),
+ stride_cn=C.stride(1),
+ ENABLE_WEIGHTS=enable_weights,
+ top_k=top_k,
+ expert_offset=expert_offset,
+ reindex_a=reindex_a,
+ reindex_c=reindex_c,
+ M_NP2=M_NP2,
+ **kernel_meta,
+ )
+
+
+def fused_moe_w8a8(input: torch.Tensor,
+ input_scale: torch.Tensor,
+ w1: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2: torch.Tensor,
+ w2_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ topk: int,
+ out_dtype: torch.dtype = torch.float16,
+ expert_offset: int = 0,
+ num_experts: int = None,
+ renormalize: bool = False) -> torch.Tensor:
+ """fused moe."""
+ device = input.device
+ M = input.size(0)
+ E, N, _ = w1.shape
+ if num_experts is None:
+ num_experts = E
+ full_exp = num_experts == E
+
+ topk_weights = _renormalize(topk_weights, renormalize)
+ sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)
+
+ intermediate_cache1 = _make_intermediate((M, topk, N),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # gate and up
+ fused_moe_w8a8_kernel_launcher(
+ input,
+ input_scale,
+ w1,
+ w1_scale,
+ intermediate_cache1,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=False,
+ top_k=topk,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=True,
+ reindex_c=False,
+ )
+
+ # activate
+ unflat_size = intermediate_cache1.shape[:-1]
+ intermediate_cache1 = intermediate_cache1.flatten(0, -2)
+ gate_cache = silu_and_mul(intermediate_cache1)
+ del intermediate_cache1
+ gate_cache = gate_cache.unflatten(0, unflat_size)
+ gate_cache, gate_scale = per_token_quant_int8(gate_cache, 1e-7)
+
+ intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # down
+ fused_moe_w8a8_kernel_launcher(
+ gate_cache,
+ gate_scale,
+ w2,
+ w2_scale,
+ intermediate_cache2,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=True,
+ top_k=1,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=False,
+ reindex_c=True,
+ )
+
+ ret = intermediate_cache2.sum(dim=1)
+ return ret
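
A hedged sketch of the activation/weight quantization the w8a8 MoE launcher
expects (illustrative shapes and flow only; assumes a CUDA device, and in the
real model path the expert weights are quantized at load time, not per call):

import torch
from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import (
    per_channel_quant, per_token_quant_int8)

def quant_inputs_for_w8a8_moe(hidden_states, w1, w2):
    # hidden_states: (M, K) fp16; w1: (E, N, K); w2: (E, K, N // 2)
    x_q, x_s = per_token_quant_int8(hidden_states, 1e-7)
    w1_q, w1_s = zip(*(per_channel_quant(w, torch.int8) for w in w1))
    w2_q, w2_s = zip(*(per_channel_quant(w, torch.int8) for w in w2))
    return (x_q, x_s,
            torch.stack(w1_q), torch.stack(w1_s).squeeze(-1),  # (E, N, K), (E, N)
            torch.stack(w2_q), torch.stack(w2_s).squeeze(-1))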
diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
index 0d0e10ec83..a8eeb63a5f 100644
--- a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
+++ b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
@@ -14,14 +14,13 @@
tl_round = tl.math.round
-def per_channel_quant(x, n_bits, dtype):
+def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
"""Quantize the input tensor 'x' channel-wise using the given number of
bits.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be a
2-dimensional tensor.
- n_bits (int): The number of bits to use for quantization.
dtype (torch.dtype): The data type to which the quantized tensor should
be converted.
@@ -32,31 +31,40 @@ def per_channel_quant(x, n_bits, dtype):
assert x.ndim == 2
x = x.to(torch.float32)
x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
- q_max = 2**(n_bits - 1) - 1
- q_min = -2**(n_bits - 1)
- scale = x_absmax / (2**(n_bits - 1) - 1)
- x_q = torch.round(x / scale).clamp(q_min, q_max).to(dtype)
+ qtype_info = torch.finfo(
+ dtype) if dtype.is_floating_point else torch.iinfo(dtype)
+ q_max = qtype_info.max
+ q_min = qtype_info.min
+ scale = x_absmax / q_max
+ x_q = x / scale
+ if not dtype.is_floating_point:
+ x_q = torch.round(x_q)
+ x_q = x_q.clamp(q_min, q_max).to(dtype)
return x_q, scale
@triton.autotune(
configs=[
triton.Config({
- 'BLOCK_N': 64,
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 256,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4),
+ num_stages=3,
+ num_warps=8),
triton.Config({
+ 'BLOCK_M': 256,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4)
+ num_stages=3,
+ num_warps=8)
],
key=['N', 'K'],
+ warmup=5,
+ rep=20,
)
-@triton.jit
+@triton.jit(do_not_specialize=['M'])
def _linear(
A,
B,
@@ -76,6 +84,7 @@ def _linear(
GROUP_SIZE_M: tl.constexpr,
rms_scale_ptr,
linear_scale_ptr,
+ ACCUMULATOR_DTYPE: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B`, and store the result in output
@@ -100,12 +109,11 @@ def _linear(
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-
- accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)
for k in range(0, tl.cdiv(K, BLOCK_K)):
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
- accumulator += tl.dot(a, b)
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)
+ accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = accumulator.to(tl.float32)
@@ -124,42 +132,31 @@ def _linear(
@triton.autotune(
configs=[
triton.Config({
- 'BLOCK_N': 64,
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 256,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4),
+ num_stages=3,
+ num_warps=8),
triton.Config({
+ 'BLOCK_M': 256,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4)
+ num_stages=3,
+ num_warps=8)
],
key=['N', 'K'],
+ warmup=5,
+ rep=20,
)
-@triton.jit
-def _linear_add(
- A,
- B,
- C,
- residual_ptr,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- GROUP_SIZE_M: tl.constexpr,
- rms_scale_ptr,
- linear_scale_ptr,
-):
+@triton.jit(do_not_specialize=['M'])
+def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak,
+ stride_bk, stride_bn, stride_cm, stride_cn,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
+ rms_scale_ptr, linear_scale_ptr,
+ ACCUMULATOR_DTYPE: tl.constexpr):
"""Triton-accelerated function used to perform a linear operation (dot
product) on input tensors `A` and `B`, with addition of residual.
@@ -183,11 +180,11 @@ def _linear_add(
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
- accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)
for k in range(0, tl.cdiv(K, BLOCK_K)):
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
- accumulator += tl.dot(a, b)
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)
+ accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = accumulator.to(tl.float32)
@@ -231,14 +228,11 @@ def matmul_kernel_dynamic_quant(a,
assert residual.shape == c_shape
assert residual.is_contiguous()
c = a.new_empty(c_shape, dtype=output_dtype)
-
- BLOCK_M = 128
- if M < BLOCK_M:
- BLOCK_M = triton.next_power_of_2(M)
- BLOCK_M = max(BLOCK_M, 16)
+ accumulator_dtype = tl.float32 if a.is_floating_point() else tl.int32
def grid(META):
- return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), )
+ return (triton.cdiv(M, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']), )
kernel_meta = get_kernel_meta(a)
if residual is not None:
@@ -255,10 +249,10 @@ def grid(META):
b.stride(0),
c.stride(-2),
c.stride(-1),
- BLOCK_M=BLOCK_M,
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale,
+ ACCUMULATOR_DTYPE=accumulator_dtype,
**kernel_meta)
else:
_linear[grid](a,
@@ -273,10 +267,10 @@ def grid(META):
b.stride(0),
c.stride(-2),
c.stride(-1),
- BLOCK_M=BLOCK_M,
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale,
+ ACCUMULATOR_DTYPE=accumulator_dtype,
**kernel_meta)
if bias is not None:
c += bias
@@ -286,13 +280,16 @@ def grid(META):
@triton.jit
def _per_token_quant_int8(
- y_ptr,
- y_q_ptr,
- y_s_ptr,
- y_stride,
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- BLOCK: tl.constexpr,
+ y_ptr,
+ y_q_ptr,
+ y_s_ptr,
+ y_stride: tl.constexpr,
+ yq_stride: tl.constexpr,
+ N, # number of columns in X
+ eps: tl.constexpr, # epsilon to avoid division by zero
+ BLOCK: tl.constexpr,
+ Q_MAX: tl.constexpr,
+ IS_FLOATING_POINT: tl.constexpr, # True for floating point dtype
):
"""A Triton-accelerated function to perform per-token quantization on a
tensor.
@@ -302,7 +299,7 @@ def _per_token_quant_int8(
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
y_ptr += row * y_stride
- y_q_ptr += row * y_stride
+ y_q_ptr += row * yq_stride
y_s_ptr += row
cols = tl.arange(0, BLOCK) # N <= BLOCK
@@ -311,21 +308,26 @@ def _per_token_quant_int8(
y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
- y_s = _absmax / 127
- y_q = tl_round(y / y_s).to(tl.int8)
+ y_s = _absmax / Q_MAX
+ y_q = y / y_s
+ if not IS_FLOATING_POINT:
+ y_q = tl_round(y_q).to(tl.int8)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
-def per_token_quant_int8(x, eps):
+def per_token_quant_int8(x, eps, quant_dtype=torch.int8):
"""Function to perform per-token quantization on an input tensor `x`.
-    It converts the tensor values into signed 8-bit integers and returns the
-    quantized tensor along with the scaling factor used for quantization.
+    It converts the tensor values into the given `quant_dtype` (int8 or fp8)
+    and returns the quantized tensor along with the per-token scales.
"""
-
- x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+ qdtype_info = torch.finfo(
+ quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
+ quant_dtype)
+ q_max = qdtype_info.max
+ x_q = torch.empty_like(x, device=x.device, dtype=quant_dtype)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_s = torch.empty(x.shape[:-1] + (1, ),
@@ -334,94 +336,184 @@ def per_token_quant_int8(x, eps):
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
+
+ if x.dim() > 2:
+ x = x.flatten(0, -2)
+ assert x.stride(-1) == 1
# enqueue kernel
kernel_meta = get_kernel_meta(x)
- _per_token_quant_int8[(M, )](x,
- x_q,
- x_s,
- x.stride(-2),
- N,
- eps,
- BLOCK=BLOCK,
- num_warps=num_warps,
- **kernel_meta)
+ _per_token_quant_int8[(M, )](
+ x,
+ x_q,
+ x_s,
+ y_stride=x.stride(-2),
+ yq_stride=x_q.stride(-2),
+ N=N,
+ eps=eps,
+ BLOCK=BLOCK,
+ Q_MAX=q_max,
+ IS_FLOATING_POINT=quant_dtype.is_floating_point,
+ num_warps=num_warps,
+ **kernel_meta)
return x_q, x_s
@triton.jit
-def _rms_norm_fwd_fused_dynamic_symmetric(
- X, # pointer to the input
- Y, # pointer to the output
- W, # pointer to the weights
- Scale, # pointer to the scales of the output activation
- stride, # how much to increase the pointer when moving by 1 row
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- BLOCK_SIZE: tl.constexpr,
+def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
+ """compute rms norm."""
+ xf = x.to(tl.float32)
+
+ var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+ out = xf * tl.math.rsqrt(var + eps)
+ out = (w * out).to(x.dtype)
+ return out
+
+
+@triton.jit
+def rms_norm_quant_kernel(
+ input,
+ weight,
+ output,
+ out_scale,
+ input_row_stride: tl.constexpr,
+ eps: tl.constexpr,
+ N_COLS: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ Q_MIN: tl.constexpr,
+ Q_MAX: tl.constexpr,
+ IS_FLOATING_POINT: tl.constexpr,
):
- """A Triton kernel that calculates Root Mean Square (RMS) normalization
- with fused dynamic symmetric quantization."""
- row = tl.program_id(0)
- Y += row * stride
- X += row * stride
+ """rms norm kernel."""
+ prog_id = tl.program_id(0)
+ offsets = tl.arange(0, BLOCK_N)
- cols = tl.arange(0, BLOCK_SIZE)
- mask = cols < N
- x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
- _var = x * x
- var = tl.sum(_var, axis=0) / N
- rstd = tl.math.rsqrt(var + eps)
+ w = tl.load(weight + offsets, mask=offsets < N_COLS)
+
+ x_ptr = input + prog_id * input_row_stride
+ x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+ out = _compute_rms_norm(x, w, eps, N_COLS)
+
+ scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+ out_s_ptr = out_scale + prog_id
+ tl.store(out_s_ptr, scale)
+ out = out / scale
+ if not IS_FLOATING_POINT:
+ out = tl_round(out)
+ out = tl.clamp(out, Q_MIN, Q_MAX)
+ out_ptr = output + prog_id * input_row_stride
+ tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
- w = tl.load(W + cols, mask=mask)
- x_hat = x * rstd
- y = x_hat * w
- scale = tl.max(tl.abs(y)).to(tl.float32) / 127
- tl.store(Scale + row, scale)
+@triton.jit
+def add_rms_norm_quant_kernel(
+ input,
+ weight,
+ residual,
+ output,
+ out_scale,
+ out_residual,
+ input_row_stride: tl.constexpr,
+ residual_row_stride: tl.constexpr,
+ eps: tl.constexpr,
+ N_COLS: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ Q_MIN: tl.constexpr,
+ Q_MAX: tl.constexpr,
+ IS_FLOATING_POINT: tl.constexpr,
+):
+ """rms norm kernel."""
+ prog_id = tl.program_id(0)
+ offsets = tl.arange(0, BLOCK_N)
+
+ w = tl.load(weight + offsets, mask=offsets < N_COLS)
- y = tl_round(y / scale)
- y = tl.minimum(y, 127)
- y = tl.maximum(y, -128)
- tl.store(Y + cols, y, mask=mask)
+ x_ptr = input + prog_id * input_row_stride
+ x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+ res_ptr = residual + prog_id * residual_row_stride
+ res = tl.load(res_ptr + offsets, mask=offsets < N_COLS)
-def rms_norm_dynamic_quant(x, w, eps):
+ new_x = x + res
+ out_res_ptr = out_residual + prog_id * residual_row_stride
+ tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)
+
+ out = _compute_rms_norm(new_x, w, eps, N_COLS)
+
+ scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+ out_s_ptr = out_scale + prog_id
+ tl.store(out_s_ptr, scale)
+ out = out / scale
+ if not IS_FLOATING_POINT:
+ out = tl_round(out)
+ out = tl.clamp(out, Q_MIN, Q_MAX)
+ out_ptr = output + prog_id * input_row_stride
+ tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
+
+
+def rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8):
"""Performs RMS normalization with dynamic quantization.
The function reshapes the input tensor `x`, creates an empty tensor `y`
with the same shape as `x`, and calculates RMS normalization on the
- reshaped `x` using a Triton kernel `_rms_norm_fwd_fused_dynamic_symmetric`.
+ reshaped `x` using a Triton kernel `rms_norm_quant_kernel`.
"""
-
- x_arg = x.flatten(0, -2)
- y = torch.empty_like(x, dtype=torch.int8)
- M, K = x_arg.shape
- MAX_FUSED_SIZE = 65536 // x.element_size()
- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))
- if K > BLOCK_SIZE:
- raise RuntimeError(
- "This rms norm doesn't support feature dim >= 64KB.")
- num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ qdtype_info = torch.finfo(
+ quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
+ quant_dtype)
+ y = torch.empty_like(x, dtype=quant_dtype)
scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)
- kernel_meta = get_kernel_meta(x_arg)
- _rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg,
- y,
- w,
- scale,
- x_arg.stride(0),
- K,
- eps,
- BLOCK_SIZE=BLOCK_SIZE,
- num_warps=num_warps,
- **kernel_meta)
- return y, scale
+
+ feat_size = w.shape[0]
+ seq_len = x.numel() // x.size(-1)
+ input_stride = x.stride(-2)
+ BLOCK_N = triton.next_power_of_2(feat_size)
+ grid = (seq_len, )
+
+ if residual is None:
+ rms_norm_quant_kernel[grid](
+ x,
+ w,
+ y,
+ scale,
+ input_row_stride=input_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ Q_MIN=qdtype_info.min,
+ Q_MAX=qdtype_info.max,
+ IS_FLOATING_POINT=quant_dtype.is_floating_point,
+ num_warps=4,
+ num_stages=2)
+ return y, scale
+ else:
+ out_residual = torch.empty_like(x)
+ res_stride = residual.stride(-2)
+ add_rms_norm_quant_kernel[grid](
+ x,
+ w,
+ residual,
+ y,
+ scale,
+ out_residual,
+ input_row_stride=input_stride,
+ residual_row_stride=res_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ Q_MIN=qdtype_info.min,
+ Q_MAX=qdtype_info.max,
+ IS_FLOATING_POINT=quant_dtype.is_floating_point,
+ num_warps=4,
+ num_stages=2)
+ return y, scale, out_residual
def test_rms_and_linear(x,
rms_weight,
linear_weight,
- dtype=torch.float16,
+ output_dtype=torch.float16,
+ quant_dtype=torch.int8,
eps=1e-5):
"""Test quantized rms norm and quantized linear layer."""
@@ -434,15 +526,18 @@ def linear_torch(x, b):
return F.linear(x, b)
linear_weight_quant, linear_scale = per_channel_quant(
- linear_weight, 8, torch.int8)
+ linear_weight, quant_dtype)
- rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)
+ rms_out, rms_scale = rms_norm_dynamic_quant(x,
+ rms_weight,
+ eps,
+ quant_dtype=quant_dtype)
assert rms_out.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1]
linear_out = matmul_kernel_dynamic_quant(rms_out,
linear_weight_quant,
rms_scale,
linear_scale,
- output_dtype=dtype)
+ output_dtype=output_dtype)
rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
linear_out_torch = linear_torch(rms_out_torch, linear_weight)
@@ -456,17 +551,26 @@ def linear_torch(x, b):
linear_out_torch.flatten().to(torch.float32)))
-def test_per_token_quant(x, eps):
+def test_per_token_quant(x, eps, quant_dtype=torch.int8):
"""Test per-token quantization."""
- def per_token_quant_int8_torch(x, eps):
+ def per_token_quant_int8_torch(x, eps, quant_dtype):
+ qdtype_info = torch.finfo(
+ quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
+ quant_dtype)
+
_absmax = torch.clamp(x.abs().max(dim=-1, keepdim=True)[0], min=eps)
- x_s = _absmax / 127
- x_q = torch.clamp((x / x_s).round(), min=-128, max=127)
+ x_s = _absmax / qdtype_info.max
+ x_q = x / x_s
+ if not quant_dtype.is_floating_point:
+ x_q = x_q.round()
+ x_q = torch.clamp(x_q, min=qdtype_info.min, max=qdtype_info.max)
return x_q, x_s
- x_q, x_s = per_token_quant_int8(x, eps)
- x_q_torch, x_s_torch = per_token_quant_int8_torch(x, eps)
+ x_q, x_s = per_token_quant_int8(x, eps, quant_dtype=quant_dtype)
+ x_q_torch, x_s_torch = per_token_quant_int8_torch(x,
+ eps,
+ quant_dtype=quant_dtype)
assert x_q.shape == x_q_torch.shape and x_s.shape == x_s_torch.shape
cos = torch.nn.CosineSimilarity(0)
print(
@@ -479,21 +583,11 @@ def per_token_quant_int8_torch(x, eps):
x_s_torch.flatten().to(torch.float32)))
-@triton.testing.perf_report(
- triton.testing.Benchmark(
- x_names=['M'],
- x_vals=[1, 16, 32, 64, 128, 256] + [512 * i * 2 for i in range(1, 17)],
- line_arg='provider',
- line_vals=['int8_dynamic_triton_op', 'float_torch'],
- line_names=['int8_dynamic_triton_op', 'float_torch'],
- styles=[('blue', '-'), ('green', '-'), ('orange', '-'),
- ('yellow', '-'), ('yellow', '-')],
- ylabel='GB/s',
- plot_name='forward',
- args={
- 'dtype': torch.float16,
- }))
-def bench_rms_and_linear(M, dtype, provider, eps=1e-5, device='cuda'):
+def bench_rms_and_linear(M: int,
+ provider: str,
+ dtype: torch.dtype = torch.float16,
+ eps: float = 1e-5):
+ """benchmark rms and linear."""
def rms_norm_torch(x, w, eps):
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
@@ -505,6 +599,7 @@ def linear_torch(x, b):
N = 4096
K = 4096
+
x_shape = (M, K)
rms_w_shape = (x_shape[-1], )
rms_weight = torch.randn(rms_w_shape,
@@ -516,14 +611,33 @@ def linear_torch(x, b):
dtype=dtype,
device='cuda',
requires_grad=True)
- linear_weight_quant, linear_scale = per_channel_quant(
- linear_weight, 8, torch.int8)
- alpha = max(x.max().abs(), x.min().abs())
- rms_scale = alpha / 127
+ if provider == 'torch_fp16':
+ rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
- if provider == 'int8_dynamic_triton_op':
- rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)
+ def y_fwd():
+ linear_torch(rms_out_torch, linear_weight)
+ else:
+ if provider == 'triton_int8':
+ quant_dtype = torch.int8
+ elif provider == 'triton_fp8_e4m3':
+ quant_dtype = torch.float8_e4m3fn
+ elif provider == 'triton_fp8_e5m2':
+ quant_dtype = torch.float8_e5m2
+
+ linear_weight_quant, linear_scale = per_channel_quant(
+ linear_weight, quant_dtype)
+
+ alpha = max(x.max().abs(), x.min().abs())
+ if quant_dtype.is_floating_point:
+ qdtype_info = torch.finfo(quant_dtype)
+ else:
+ qdtype_info = torch.iinfo(quant_dtype)
+ rms_scale = alpha / qdtype_info.max
+ rms_out, rms_scale = rms_norm_dynamic_quant(x,
+ rms_weight,
+ eps,
+ quant_dtype=quant_dtype)
def y_fwd():
@@ -532,21 +646,22 @@ def y_fwd():
rms_scale,
linear_scale,
output_dtype=dtype)
- elif provider == 'float_torch':
- rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
-
- def y_fwd():
- linear_torch(rms_out_torch, linear_weight)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd,
quantiles=quantiles,
rep=500)
- return ms, max_ms, min_ms
+
+ def perf(ms):
+ return 2 * M * N * K * 1e-12 / (ms * 1e-3)
+
+ return perf(ms), perf(max_ms), perf(min_ms)
if __name__ == '__main__':
torch.manual_seed(0)
+    device_cap = torch.cuda.get_device_capability()
+    is_fp8_supported = device_cap[0] >= 9
dtype = torch.float16
# test (bs, seq_len, dim) x (dim, out_dim)
x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda')
@@ -559,7 +674,16 @@ def y_fwd():
dtype=dtype,
device='cuda',
requires_grad=True)
- test_rms_and_linear(x, rms_weight, linear_weight)
+ test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8)
+ if is_fp8_supported:
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e4m3fn)
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e5m2)
# test (M, K) x (K, N)
x = torch.randn((4, 4096), dtype=dtype, device='cuda')
@@ -572,11 +696,45 @@ def y_fwd():
dtype=dtype,
device='cuda',
requires_grad=True)
- test_rms_and_linear(x, rms_weight, linear_weight)
+ test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8)
+ if is_fp8_supported:
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e4m3fn)
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e5m2)
# test per-token quant
x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda')
eps = 1e-7
- test_per_token_quant(x, eps)
-
- bench_rms_and_linear.run(print_data=True)
+ test_per_token_quant(x, eps, quant_dtype=torch.int8)
+ if is_fp8_supported:
+ test_per_token_quant(x, eps, quant_dtype=torch.float8_e4m3fn)
+ test_per_token_quant(x, eps, quant_dtype=torch.float8_e5m2)
+
+ # benchmark triton kernels
+ line_vals = ['triton_int8', 'torch_fp16']
+ line_names = ['triton_int8', 'torch_fp16']
+
+ if is_fp8_supported:
+ line_vals += ['triton_fp8_e4m3', 'triton_fp8_e5m2']
+ line_names += ['triton_fp8_e4m3', 'triton_fp8_e5m2']
+ config = triton.testing.Benchmark(x_names=['M'],
+ x_vals=[1, 16, 32, 64, 128, 256] +
+ [512 * i * 2 for i in range(1, 5)],
+ line_arg='provider',
+ line_vals=line_vals,
+ line_names=line_names,
+ styles=[('blue', '-'), ('green', '-'),
+ ('orange', '-'), ('black', '-'),
+ ('yellow', '-')],
+ ylabel='TFLOPS',
+ plot_name='bench-triton',
+ args={
+ 'dtype': torch.float16,
+ })
+    bench_func = triton.testing.perf_report(config)(bench_rms_and_linear)
+    bench_func.run(print_data=True)
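
A hedged usage sketch of the generalized entry points (assumes a CUDA device;
the fp8 variants additionally need hardware and toolchain fp8 support):

import torch
from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import (
    per_token_quant_int8, rms_norm_dynamic_quant)

x = torch.randn(4, 4096, device='cuda', dtype=torch.float16)
w = torch.ones(4096, device='cuda', dtype=torch.float16)

x_q, x_s = per_token_quant_int8(x, 1e-7, quant_dtype=torch.int8)
y_q, y_s = rms_norm_dynamic_quant(x, w, 1e-5, quant_dtype=torch.float8_e4m3fn)
# with a residual, the fused kernel also returns the updated residual
y_q, y_s, new_res = rms_norm_dynamic_quant(
    x, w, 1e-5, residual=torch.zeros_like(x), quant_dtype=torch.int8)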
diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
index 8f86f0019a..fe82010761 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -3,6 +3,7 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
+from .flash_attention import flash_attention_fwd
from .fused_moe import fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
@@ -16,6 +17,7 @@
'fill_kv_cache',
'fused_moe',
'paged_attention_fwd',
+ 'flash_attention_fwd',
'linear',
'moe_gating_topk_softmax',
'multinomial_sampling',
diff --git a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
index fb2eee9d41..63564d7ed8 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
import dlinfer.ops as ext_ops
from torch import Tensor
@@ -9,7 +11,16 @@ def fill_kv_cache(
key_caches: Tensor,
value_caches: Tensor,
kv_start_indices: Tensor,
+ k_scales_zeros: Sequence[Optional[Tensor]],
+ v_scales_zeros: Sequence[Optional[Tensor]],
+ quant_bits: int = 0,
):
"""fill key/value state to cache for paged attention."""
- return ext_ops.fill_kv_cache(key_states, value_states, key_caches,
- value_caches, kv_start_indices)
+ return ext_ops.fill_kv_cache(key_states,
+ value_states,
+ key_caches,
+ value_caches,
+ kv_start_indices,
+ k_scales_zeros=k_scales_zeros,
+ v_scales_zeros=v_scales_zeros,
+ quant_bits=quant_bits)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
new file mode 100644
index 0000000000..1788f947ee
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+from dlinfer.utils.type_annotation import Tensor
+
+
+def flash_attention_fwd(
+ query_states: Tensor,
+ key_states: Tensor,
+ value_states: Tensor,
+ attn_output: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None,
+ window_size: int = None,
+ sm_scale: float = None,
+ logit_softcapping: float = None,
+ causal: bool = True,
+):
+ num_q_heads = query_states.shape[1]
+ num_kv_heads = value_states.shape[1]
+ return ext_ops.prefill_attention(
+ query_states,
+ key_states,
+ value_states,
+ q_start_loc,
+ q_seqlens,
+ max_q_seqlen,
+ num_q_heads,
+ num_kv_heads,
+ attn_mask=None,
+ softmax_scale=sm_scale,
+ attn_output=attn_output,
+ )
diff --git a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
index 72bab2d720..275ea65261 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
@@ -5,12 +5,13 @@
def fused_moe(
hidden_states: Tensor,
- top_k: int,
- topk_ids: Tensor,
- topk_weights: Tensor,
gate_up_weights: Tensor,
down_weights: Tensor,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
+ topk: int,
+ renormalize: bool,
):
- """ascend fused moe."""
- return ext_ops.fused_moe(hidden_states, top_k, topk_ids, topk_weights,
- gate_up_weights, down_weights)
+ """dlinfer fused moe."""
+ return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights,
+ topk_weights, topk_ids, topk, renormalize)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index 47bcb0cfff..ded85d476d 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -19,6 +19,9 @@ def prefill_attention(
block_size: int,
attn_mask: Sequence[Optional[Tensor]],
is_unpaged_prefill: Optional[bool],
+ kv_scales: Optional[Tensor],
+ kv_zeros: Optional[Tensor],
+ quant_bits: Optional[int],
) -> Tensor:
num_q_heads = query_states.shape[1]
num_kv_heads = value_states.shape[1]
@@ -53,11 +56,25 @@ def prefill_attention(
num_kv_heads,
attn_mask,
attn_output=attn_output,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
-def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
- max_kv_seq_len, block_offsets, block_size):
+def paged_token_attention(
+ q,
+ k_cache,
+ v_cache,
+ attn_output,
+ kv_seq_len,
+ max_kv_seq_len,
+ block_offsets,
+ block_size,
+ kv_scales: Optional[Tensor],
+ kv_zeros: Optional[Tensor],
+ quant_bits: Optional[int],
+):
num_q_heads, q_head_dim = q.shape[1:3]
num_kv_heads = k_cache.shape[-1] // q_head_dim
return ext_ops.paged_decode_attention(
@@ -71,6 +88,9 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
num_q_heads,
num_kv_heads,
attn_output=attn_output,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
@@ -91,6 +111,9 @@ def paged_attention_fwd(
block_size: int,
attn_mask: Sequence[Optional[Tensor]] = (),
is_unpaged_prefill: Optional[bool] = None,
+ kv_scales: Optional[Tensor] = None,
+ kv_zeros: Optional[Tensor] = None,
+ quant_bits: Optional[int] = 0,
):
if not is_decoding:
return prefill_attention(
@@ -108,6 +131,9 @@ def paged_attention_fwd(
block_size,
attn_mask,
is_unpaged_prefill,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
else:
return paged_token_attention(
@@ -119,4 +145,7 @@ def paged_attention_fwd(
max_kv_seq_len,
block_offsets,
block_size,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index b16a78f1f4..968b71fee1 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -8,6 +8,7 @@
from torch import Tensor
from lmdeploy.messages import GenerationConfig, LogitsProcessor
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
from lmdeploy.utils import get_logger
from .block import LogicalTokenBlocks
@@ -205,10 +206,9 @@ def add_sequence(
sampling_param: SamplingParam = None,
adapter_name: str = None,
return_logits: bool = False,
- input_embeddings: List[InputEmbeddings] = None,
- mrope_position_ids: Tensor = None,
- mrope_position_delta: Tensor = None,
- cross_attention_states: Tensor = None) -> 'SchedulerSequence':
+ multimodals: MultiModalInputs = None,
+ input_embeddings: List[InputEmbeddings] = None
+ ) -> 'SchedulerSequence':
"""Add a new message."""
if isinstance(token_ids, Tensor):
token_ids = token_ids.numpy()
@@ -228,10 +228,8 @@ def add_sequence(
adapter_name=adapter_name,
arrive_time=time.time(),
history_embeddings=HistoryEmbeddings(input_embeddings),
+ history_multimodals=HistoryMultiModals(multimodals),
return_logits=return_logits,
- mrope_position_ids=mrope_position_ids,
- mrope_position_delta=mrope_position_delta,
- cross_attention_states=cross_attention_states,
)
self.sequences[seq.seq_id] = seq
if self.seq_manager is not None:
@@ -361,6 +359,66 @@ def copy(self):
return self.clone()
+class HistoryMultiModals:
+
+ def __init__(self, multimodals: MultiModalInputs):
+ if multimodals is None:
+ multimodals = dict()
+ self.multimodals = multimodals
+
+ def get_datas(self, start=0, end=-1):
+ """get multimodals from prompts position [start, end)."""
+ outs = dict()
+ test_range = range(start, end)
+ for modal_type, modal_datas in self.multimodals.items():
+ data = []
+ for modal_data in modal_datas:
+ if (modal_data.start not in test_range
+ and modal_data.end not in test_range):
+ continue
+ data.append(modal_data)
+ if len(data) > 0:
+ outs[modal_type] = data
+ return outs
+
+ def add_inputs(self, input_mms: MultiModalInputs):
+ """add new inputs."""
+ for modal_type, vals in input_mms.items():
+ if modal_type in self.multimodals:
+ self.multimodals[modal_type] += vals
+ else:
+ self.multimodals[modal_type] = vals
+
+    def empty(self):
+        if len(self.multimodals) == 0:
+            return True
+
+        return all(len(vals) == 0 for vals in self.multimodals.values())
+
+ @staticmethod
+ def update_multimodals(input_mms: MultiModalInputs, prev_len: int):
+ """update multimodals."""
+ for vals in input_mms.values():
+ for val in vals:
+ val.start += prev_len
+ val.end += prev_len
+ return input_mms
+
+ def get_encoder_len(self, start=0, end=-1):
+ """get lens of encoder."""
+ test_range = range(start, end)
+ out_len = 0
+ for _, modal_datas in self.multimodals.items():
+ for modal_data in modal_datas:
+ if modal_data.encoder_len is None:
+ continue
+ if (modal_data.start not in test_range
+ and modal_data.end not in test_range):
+ continue
+ out_len += modal_data.encoder_len
+ return out_len
+
+
@dataclass
class SchedulerSequence:
"""Scheduler message."""
@@ -369,12 +427,12 @@ class SchedulerSequence:
history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds)
history_embeddings: HistoryEmbeddings = field(
default_factory=HistoryEmbeddings)
+ history_multimodals: HistoryMultiModals = field(
+ default_factory=HistoryMultiModals)
num_new_tokens: int = 0
sampling_param: SamplingParam = field(default_factory=SamplingParam)
logical_blocks: LogicalTokenBlocks = field(
default_factory=LogicalTokenBlocks)
- sender_id: int = -1
- req_id: int = -1
adapter_name: str = None
arrive_time: float = 0.0
meta: Any = None
@@ -382,10 +440,7 @@ class SchedulerSequence:
random_offsets: int = 0
_status: MessageStatus = field(default=MessageStatus.WAITING, init=False)
num_ignored_history: int = 0
- mrope_position_ids: Optional[Tensor] = None
- mrope_position_delta: Optional[int] = None
- cross_attention_states: Optional[Tensor] = None
- history_cross_kv_seqlens: int = 0
+ model_meta: Dict[str, Any] = None
def __post_init__(self):
"""post init."""
@@ -394,6 +449,10 @@ def __post_init__(self):
self._num_images: int = len(self.history_embeddings)
self._num_token_ids: int = len(self.history_cache)
+ self._num_history_cross: int = 0
+ self._num_cross: int = self.history_multimodals.get_encoder_len(
+ 0, self._num_token_ids)
+
@property
def block_size(self) -> int:
"""block size."""
@@ -464,6 +523,16 @@ def num_all_ids(self):
"""num all tokens."""
return self.history_len + self._num_token_ids
+ @property
+ def num_cross(self):
+ """num cross."""
+ return self._num_cross
+
+ @property
+ def num_history_cross(self):
+ """num history cross."""
+ return self._num_history_cross
+
@property
def num_blocks(self):
"""num blocks."""
@@ -489,22 +558,22 @@ def num_all_tokens(self):
def num_all_cross_tokens(self):
"""num of all cross tokens."""
- if self.cross_attention_states is None:
- self.history_cross_kv_seqlens = 0
- else:
- self.history_cross_kv_seqlens = self.cross_attention_states.shape[
- -2]
- return self.history_cross_kv_seqlens
+ return self._num_cross + self._num_history_cross
+
+ def get_input_multimodals(self):
+ """get input multimodals."""
+ start = self.num_history_ids
+ end = self.num_all_ids
+ return self.history_multimodals.get_datas(start, end)
def update_token_ids(self,
token_ids: Tensor,
+ multimodals: MultiModalInputs = None,
embeddings: List[InputEmbeddings] = None,
- cross_attention_states: List[Tensor] = None):
+ model_meta: Dict[str, Any] = None):
"""Update token ids, old token ids will be added to history."""
- # cross attention
- if cross_attention_states is not None:
- self.history_cross_kv_seqlens += cross_attention_states.shape[-2]
- self.cross_attention_states = cross_attention_states
+ old_num_history_ids = self._num_history_ids
+
self._num_history_ids += self._num_token_ids
# update history image nums
self._num_history_images += self._num_images
@@ -516,6 +585,23 @@ def update_token_ids(self,
self._num_images = len(new_embeddings)
self.history_embeddings.append(new_embeddings)
+ # update multimodals
+ if multimodals is not None:
+ multimodals = HistoryMultiModals.update_multimodals(
+ multimodals, self.num_all_ids)
+ self.history_multimodals.add_inputs(multimodals)
+
+ # cross
+ self._num_history_cross += self._num_cross
+ if multimodals is not None:
+ self._num_cross = self.history_multimodals.get_encoder_len(
+ old_num_history_ids, self._num_history_ids)
+ else:
+ self._num_cross = 0
+
+ if model_meta is not None:
+ self.model_meta = model_meta
+
if isinstance(token_ids, Tensor):
token_ids = token_ids.numpy()
elif not isinstance(token_ids, np.ndarray):
@@ -539,3 +625,12 @@ def set_step(self, step: int):
self._num_history_ids = step
self._num_token_ids = num_all_ids - step
self.num_ignored_history = min(step, self.num_ignored_history)
+
+ self.model_meta = None
+
+ # cross
+ if self.history_multimodals is not None:
+ self._num_history_cross = self.history_multimodals.get_encoder_len(
+ 0, self.num_history_ids)
+ self._num_cross = self.history_multimodals.get_encoder_len(
+ self._num_history_ids, num_all_ids)
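The `HistoryMultiModals` container added above replaces the ad-hoc mrope/cross-attention fields with a single per-modality dict of items carrying `start`/`end` prompt offsets and an optional `encoder_len`. A minimal usage sketch, assuming stand-in items with just those three fields (the real entries are `MultiModalTensor` objects from `lmdeploy.pytorch.multimodal.data_type`):

from types import SimpleNamespace

# hypothetical stand-ins; only start/end/encoder_len are inspected here
img0 = SimpleNamespace(start=3, end=9, encoder_len=6)    # covers prompt positions [3, 9)
img1 = SimpleNamespace(start=20, end=26, encoder_len=6)  # covers prompt positions [20, 26)

hist = HistoryMultiModals(dict(image=[img0, img1]))

hist.get_datas(0, 12)        # -> {'image': [img0]}; only img0 overlaps [0, 12)
hist.get_encoder_len(0, 12)  # -> 6 encoder tokens contributed inside that range
hist.empty()                 # -> False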
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 0ae5dd7986..e984e39abe 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -4,47 +4,19 @@
from typing import Any, Dict, List, Literal
import torch
+from torch import distributed as dist
from lmdeploy.pytorch.backends import get_backend
+from lmdeploy.pytorch.config import ModelConfig
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-@dataclass
-class MRopeModelInputs:
- """Multimodal rotary position inputs."""
- position_ids: List[torch.LongTensor] = None
- deltas: List[torch.LongTensor] = None
-
- def get_inputs(self, history_lengths: torch.Tensor,
- seq_lengths: torch.Tensor):
- mrope_position_ids = []
- for (his_len, seq_len, pos_ids,
- delta) in zip(history_lengths, seq_lengths, self.position_ids,
- self.deltas):
- assert pos_ids.dim() == 2, 'invalid mrope_position_ids'
- if his_len + seq_len <= pos_ids.shape[1]:
- mrope_position_ids.append(pos_ids[:,
- his_len:his_len + seq_len])
- else:
- mrope_position_ids.append(
- torch.tensor([his_len], device=delta.device).expand(3, -1)
- + delta)
-
- mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
- return mrope_position_ids
-
- def to_device(self, device: str):
- """to device."""
- out_dict = dict()
- for f in fields(self):
- k = f.name
- v = getattr(self, k)
- if isinstance(v, torch.Tensor):
- v = v.to(device)
- elif isinstance(v, list):
- v = [x.to(device) for x in v]
- out_dict[k] = v
-
- return MRopeModelInputs(**out_dict)
+def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
+ """broadcast tensor."""
+ if value.device.type == 'meta':
+ value = torch.empty_like(value, device=device)
+ dist.broadcast(value, src)
+ return value
@dataclass
@@ -56,6 +28,7 @@ class VisionModelInputs:
input_embeddings: List[List[torch.Tensor]] = None
input_embedding_ranges: List[torch.LongTensor] = None
input_embedding_indexing: torch.BoolTensor = None
+ input_multimodals: List[MultiModalTensor] = None
def to_device(self, device: str):
"""to device."""
@@ -63,12 +36,54 @@ def to_device(self, device: str):
for f in fields(self):
k = f.name
v = getattr(self, k)
+ if v is None:
+ continue
if isinstance(v, torch.Tensor):
v = v.to(device)
- elif k == 'input_embedding_ranges' and v is not None:
+ elif k == 'input_embedding_ranges':
v = [e.to(device) for e in v]
- elif k == 'input_embeddings' and v is not None:
+ elif k == 'input_embeddings':
v = [[e.to(device) for e in li] for li in v]
+ elif k == 'input_multimodals':
+ new_v = []
+ for mm_datas in v:
+ new_mm_datas = dict()
+ for modal_type, data in mm_datas.items():
+ data = [d.to_device(device) for d in data]
+ new_mm_datas[modal_type] = data
+ new_v.append(new_mm_datas)
+ v = new_v
+ out_dict[k] = v
+
+ return VisionModelInputs(**out_dict)
+
+ def broadcast(self):
+ """broadcast inputs.
+
+ Do `dist.broadcast_object_list(inputs.to_device('meta'))`
+ before broadcasting the tensors.
+ """
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ v = getattr(self, k)
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = _broadcast_tensor(v)
+ elif k == 'input_embedding_ranges':
+ v = [_broadcast_tensor(e) for e in v]
+ elif k == 'input_embeddings':
+ v = [[_broadcast_tensor(e) for e in li] for li in v]
+ elif k == 'input_multimodals':
+ new_v = []
+ for mm_datas in v:
+ new_mm_datas = dict()
+ for modal_type, data in mm_datas.items():
+ data = [d.broadcast() for d in data]
+ new_mm_datas[modal_type] = data
+ new_v.append(new_mm_datas)
+ v = new_v
out_dict[k] = v
return VisionModelInputs(**out_dict)
@@ -119,12 +134,12 @@ class ModelInputs:
num_ignored_history: torch.LongTensor
local_adapter_ids: torch.LongTensor = None
vision_inputs: VisionModelInputs = None
- mrope_inputs: MRopeModelInputs = None
- cross_attention_states: torch.Tensor = None
- history_cross_kv_seqlens: torch.LongTensor = None
+ cross_length: torch.LongTensor = None
+ history_cross_length: torch.LongTensor = None
last_hidden_states: torch.Tensor = None
medusa_attn_mask: torch.Tensor = None
medusa_position_ids: torch.Tensor = None
+ model_metas: List[Dict[str, Any]] = None
def update(self, input_ids: torch.LongTensor):
"""update input ids."""
@@ -135,44 +150,88 @@ def update(self, input_ids: torch.LongTensor):
self.input_ids = input_ids
return self
- def split(self, split_size: int, block_size: int):
+ def split(self, split_size: int):
"""split inputs."""
assert len(
self.seq_length) == 1, ('Can not perform split on batched input.')
- assert split_size % block_size == 0, (
- 'split_size should be multi of block_size.')
input_ids = self.input_ids
if input_ids.numel() < split_size:
return self
- num_blocks = split_size // block_size
- overlap = (self.history_lengths[0] % block_size != 0)
+ flatten_mms = []
+ vision_inputs = self.vision_inputs
+ if vision_inputs is not None:
+ if vision_inputs.input_multimodals is not None:
+ input_mms = vision_inputs.input_multimodals[0]
+
+ flatten_mms = []
+ for k, mms in input_mms.items():
+ mms = [(k, mm) for mm in mms]
+ flatten_mms += mms
+
+ flatten_mms = sorted(flatten_mms, key=lambda mm: mm[1].start)
+
max_seq_len = self.seq_length[0].item()
ret = []
- block_start = 0
- for i in range(0, max_seq_len, split_size):
- start = i
- end = min(max_seq_len, i + split_size)
- block_end = block_start + num_blocks
- if overlap:
- block_end += 1
-
- block_offsets = self.block_offsets
+ start = 0
+ history_cross_length = self.history_cross_length
+ cross_length = None
+ if history_cross_length is not None:
+ cross_length = self.history_cross_length.clone()
+ while start < max_seq_len:
+ vision_inputs = None
+ if len(flatten_mms) > 0:
+ mm_start = flatten_mms[0][1].start
+ mm_end = flatten_mms[0][1].end
+ if mm_start > self.history_lengths + start:
+ end = min(mm_start - self.history_lengths,
+ start + split_size)
+ else:
+ input_mms = dict()
+ key, mm = flatten_mms.pop(0)
+ input_mms.setdefault(key, [])
+ input_mms[key].append(mm)
+ end = start + mm.end - mm.start
+ while len(flatten_mms) > 0:
+ next_mm = flatten_mms[0]
+ next_start = next_mm[1].start
+ next_end = next_mm[1].end
+ if next_start < mm_end:
+ key = next_mm[0]
+ input_mms.setdefault(key, [])
+ input_mms[key].append(next_mm[1])
+ end += max(0, next_end - mm_end)
+ flatten_mms.pop(0)
+
+ if cross_length is not None:
+ encoder_len = next_mm[1].encoder_len
+ if encoder_len is not None:
+ cross_length += encoder_len
+ else:
+ break
+ vision_inputs = VisionModelInputs(
+ input_multimodals=[input_mms], )
+ else:
+ end = min(max_seq_len, start + split_size)
+
inp = ModelInputs(
input_ids=self.input_ids[:, start:end],
seq_length=input_ids.new_tensor([end - start]),
- block_offsets=block_offsets,
+ block_offsets=self.block_offsets,
history_lengths=self.history_lengths + start,
is_decoding=self.is_decoding,
num_ignored_history=self.num_ignored_history,
local_adapter_ids=self.local_adapter_ids,
- vision_inputs=self.vision_inputs,
- mrope_inputs=self.mrope_inputs,
- cross_attention_states=self.cross_attention_states,
+ vision_inputs=vision_inputs,
+ model_metas=self.model_metas,
+ cross_length=cross_length,
+ history_cross_length=history_cross_length,
)
ret.append(inp)
- block_start += num_blocks
+ history_cross_length = cross_length
+
+ start = end
return ret
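The reworked `split` no longer cuts on block boundaries: it walks the flattened multimodal items in prompt order and keeps each item inside a single chunk, letting that chunk exceed `split_size` when necessary. A simplified sketch of the chunking rule (single sequence, zero history offset and non-overlapping spans assumed; this is not the engine code itself):

def sketch_chunk_bounds(seq_len: int, split_size: int, mm_spans):
    """Mirror the chunk-boundary logic of ModelInputs.split in plain Python."""
    spans = sorted(mm_spans)  # list of (start, end) prompt positions
    bounds, start = [], 0
    while start < seq_len:
        if spans and spans[0][0] <= start:
            s, e = spans.pop(0)
            end = start + (e - s)   # keep the whole multimodal item in one chunk
        elif spans:
            end = min(spans[0][0], start + split_size)  # stop at the next item
        else:
            end = min(seq_len, start + split_size)
        bounds.append((start, end))
        start = end
    return bounds

# a 16-token prefill with split_size=4 and one image over positions [4, 10):
# sketch_chunk_bounds(16, 4, [(4, 10)]) -> [(0, 4), (4, 10), (10, 14), (14, 16)]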
@@ -186,8 +245,24 @@ def to_device(self, device: str):
v = v.to(device)
elif isinstance(v, VisionModelInputs):
v = v.to_device(device)
- elif isinstance(v, MRopeModelInputs):
- v = v.to_device(device)
+ out_dict[k] = v
+
+ return ModelInputs(**out_dict)
+
+ def broadcast(self):
+ """broadcast inputs.
+
+ Do `dist.broadcast_object_list(inputs.to_device('meta'))`
+ before broadcasting the tensors.
+ """
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ v = getattr(self, k)
+ if isinstance(v, torch.Tensor):
+ v = _broadcast_tensor(v)
+ elif isinstance(v, VisionModelInputs):
+ v = v.broadcast()
out_dict[k] = v
return ModelInputs(**out_dict)
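The new `broadcast` helpers implement a two-step hand-off for tensor parallelism: the driver first ships a structure-only copy (tensors moved to the `meta` device) through `dist.broadcast_object_list`, and then every rank calls `broadcast()` so the real payloads travel over `dist.broadcast`. A hedged sketch of how a caller might wire this up, assuming `torch.distributed` is initialized and rank 0 holds `inputs` on GPU:

import torch.distributed as dist

rank = dist.get_rank()
if rank == 0:
    # send only the structure: every tensor becomes a meta placeholder
    dist.broadcast_object_list([inputs.to_device('meta')], src=0)
else:
    objs = [None]
    dist.broadcast_object_list(objs, src=0)
    inputs = objs[0]

# rank 0 still holds real tensors and acts as the broadcast source; the other
# ranks swap each meta placeholder for an empty CUDA buffer and receive the
# payload (see _broadcast_tensor above).
inputs = inputs.broadcast()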
@@ -201,6 +276,7 @@ class StepContext:
dataclass provide these infos and tools.
"""
input_ids: torch.LongTensor
+ model_config: ModelConfig
block_offsets: torch.LongTensor
position_ids: torch.LongTensor
attention_mask: torch.LongTensor
@@ -213,15 +289,16 @@ class StepContext:
local_adapter_ids: torch.LongTensor = None
input_embeddings: torch.Tensor = None
input_embedding_indexing: torch.Tensor = None
+ input_multimodals: List[MultiModalTensor] = None
vision_inputs: VisionModelInputs = None
- mrope_position_ids: torch.Tensor = None
attn_metadata: Any = None
- cross_attn_metadata: Any = None
- cross_attention_states: torch.Tensor = None
+ cross_seqlens: torch.LongTensor = None
cross_kv_seqlens: torch.LongTensor = None
+ cross_attn_metadata: Any = None
kv_quant_policy: Literal[0, 4, 8] = 0
last_hidden_states: torch.Tensor = None
medusa_attn_mask: torch.Tensor = None
+ model_metas: List[Dict[str, Any]] = None
_outputs: Dict = field(default_factory=dict)
@@ -229,6 +306,7 @@ class StepContext:
def new(
cls,
inputs: ModelInputs,
+ model_config: ModelConfig,
world_size: int = 1,
kv_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
@@ -244,34 +322,38 @@ def new(
history_seqlens = inputs.history_lengths
device = q_seqlens.device
+ input_multimodals = None
+ if inputs.vision_inputs is not None:
+ input_multimodals = inputs.vision_inputs.input_multimodals
+
# for vlm
input_embeddings, input_embedding_indexing = None, None
if (inputs.vision_inputs is not None
and inputs.vision_inputs.input_embeddings is not None):
input_embeddings, input_embedding_indexing = \
inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens)
- # for mrope
- mrope_position_ids = None
- if inputs.mrope_inputs is not None:
- mrope_position_ids = inputs.mrope_inputs.get_inputs(
- history_seqlens, q_seqlens)
# for speculative decoding
last_hidden_states = inputs.last_hidden_states
# kv_seqlens
- cross_attention_states = inputs.cross_attention_states
if inputs.is_decoding:
attention_mask = torch.ones_like(q_seqlens)[:, None]
- position_ids = history_seqlens.unsqueeze(-1)
- cross_attention_states = None
+ position_ids = history_seqlens.unsqueeze(-1).clone()
else:
- max_q_seqlen = q_seqlens.max().item()
+ max_q_seqlen = q_seqlens.contiguous().max().item()
mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
attention_mask = (mask_range < q_seqlens[:, None]).long()
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids += history_seqlens.unsqueeze(-1)
q_start_loc = q_seqlens.cumsum(0) - q_seqlens
+ # cross
+ cross_seqlens = inputs.cross_length
+ cross_kv_seqlens = None
+ if inputs.cross_length is not None:
+ cross_kv_seqlens = (inputs.cross_length +
+ inputs.history_cross_length)
+
# position ids 1d
position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None]
# seq_len + history_length
@@ -283,10 +365,12 @@ def new(
ret = StepContext(
input_ids=inputs.input_ids,
+ model_config=model_config,
block_offsets=inputs.block_offsets,
position_ids=position_ids,
input_embeddings=input_embeddings,
input_embedding_indexing=input_embedding_indexing,
+ input_multimodals=input_multimodals,
attention_mask=attention_mask,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
@@ -296,12 +380,12 @@ def new(
world_size=world_size,
local_adapter_ids=inputs.local_adapter_ids,
vision_inputs=inputs.vision_inputs,
- mrope_position_ids=mrope_position_ids,
- cross_attention_states=cross_attention_states,
last_hidden_states=last_hidden_states,
medusa_attn_mask=inputs.medusa_attn_mask,
- cross_kv_seqlens=inputs.history_cross_kv_seqlens,
kv_quant_policy=kv_quant_policy,
+ model_metas=inputs.model_metas,
+ cross_seqlens=cross_seqlens,
+ cross_kv_seqlens=cross_kv_seqlens,
)
ret = get_backend().update_step_context(ret)
@@ -330,6 +414,7 @@ def __init__(self):
@staticmethod
def build_context(
inputs: ModelInputs,
+ model_config: ModelConfig,
world_size: int = 1,
kv_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
@@ -337,6 +422,7 @@ def build_context(
"""build context."""
return StepContext.new(
inputs,
+ model_config,
world_size,
kv_caches,
kv_quant_policy,
diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py
index 583cd19fe9..38d794f1be 100644
--- a/lmdeploy/pytorch/models/baichuan.py
+++ b/lmdeploy/pytorch/models/baichuan.py
@@ -228,7 +228,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -245,7 +244,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py
index 8d7a21a0a6..5a83154167 100644
--- a/lmdeploy/pytorch/models/chatglm2.py
+++ b/lmdeploy/pytorch/models/chatglm2.py
@@ -1,101 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
+from torch.nn import functional as F
from transformers.configuration_utils import PretrainedConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1
-def get_vision_expert_mask(token_type_ids: torch.LongTensor):
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1]
- == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:]
- == VISION_TOKEN_TYPE)
- language_token_mask = ~vision_token_mask
- return vision_token_mask, language_token_mask
-
-
-def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor:
- tmp = x.clone()
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
- tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
- tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
- # final position ids
- y = torch.zeros_like(x, dtype=torch.long)
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
- (tmp[:, 1:] == VISION_TOKEN_TYPE) &
- (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
- y = y.cumsum(dim=-1)
- return y
-
-
-def _get_cogvlm_position_ids(context):
- """get cogvlm position_ids."""
- q_seqlens = context.q_seqlens
- history_lengths = context.kv_seqlens - q_seqlens
- vision_input_info = context.vision_inputs
- position_id_offsets = (vision_input_info.history_image_token_lengths -
- vision_input_info.history_image_nums * 3)
- lang_ids = None
- vis_ids = None
- if context.is_decoding:
- position_ids = history_lengths - position_id_offsets
- else:
- if vision_input_info.input_embeddings is not None and len(
- vision_input_info.input_embeddings) > 0:
- starts = history_lengths - vision_input_info.history_lengths
- ends = starts + q_seqlens
- token_type_ids = vision_input_info.input_embedding_indexing.to(
- torch.int)
- history_position_lengths = (vision_input_info.history_lengths -
- position_id_offsets)
- position_ids_all = (history_position_lengths[:, None] +
- build_position_ids(token_type_ids))
- position_ids = torch.cat([
- pids[s:e]
- for (pids, s, e) in zip(position_ids_all, starts, ends)
- ])
- vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids)
- vision_token_mask = torch.cat([
- masks[s:e]
- for (masks, s, e) in zip(vision_token_mask_all, starts, ends)
- ])
- mask_indexing = torch.arange(vision_token_mask.shape[-1],
- device=vision_token_mask.device)
- vis_ids = mask_indexing[vision_token_mask]
- lang_ids = mask_indexing[~vision_token_mask]
-
- else:
- position_ids = context.attention_mask.long().cumsum(-1) - 1
- position_ids += (history_lengths -
- position_id_offsets).unsqueeze(-1)
- device = position_ids.device
- position_ids_1d = [
- ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu())
- ]
- position_ids = torch.cat(position_ids_1d).to(device)
-
- return position_ids, lang_ids, vis_ids
-
-
class SelfAttention(torch.nn.Module):
"""Parallel self-attention layer abstract class.
@@ -112,11 +40,10 @@ def __init__(self,
self.projection_size = config.kv_channels * config.num_attention_heads
self.num_attention_heads = config.num_attention_heads
- self.num_kv_heads = self.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
self.head_size = (self.projection_size // config.num_attention_heads)
- self.multi_query_attention = config.multi_query_attention
- if self.multi_query_attention:
- self.num_kv_heads = config.multi_query_group_num
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
self.query_key_value = build_qkv_proj(
config.hidden_size,
num_q_heads=self.num_attention_heads,
@@ -126,7 +53,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# apply rotary
self.apply_rotary_pos_emb = ApplyRotaryEmb()
@@ -338,7 +265,6 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = getattr(config, 'quantization_config', None)
self.num_layers = config.num_layers
self.post_layer_norm = config.post_layer_norm
@@ -353,7 +279,6 @@ def build_layer(layer_number):
assert config.rmsnorm
self.final_layernorm = RMSNorm(config.hidden_size,
config.layernorm_epsilon,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -410,6 +335,286 @@ def forward(self, input_ids):
return embeddings
+class PatchEmbedding(nn.Module):
+ """vision embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.proj = nn.Conv2d(config.in_channels,
+ config.hidden_size,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ dtype=dtype,
+ device=device)
+ self.cls_embedding = nn.Parameter(
+ torch.empty(1, config.hidden_size, dtype=dtype, device=device))
+ self.position_embedding = nn.Embedding(config.num_positions,
+ config.hidden_size,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, images):
+ """forward."""
+ x = self.proj(images)
+ x = x.flatten(2).transpose(1, 2)
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_token, x), dim=1)
+ x += self.position_embedding.weight.unsqueeze(0)
+ return x
+
+
+class EVA2CLIPAttention(nn.Module):
+ """vision attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ hidden_size = config.hidden_size
+ num_heads = config.num_heads
+ head_dim = config.hidden_size // config.num_heads
+ self.scale = head_dim**-0.5
+
+ # packed qkv
+ self.query_key_value = build_qkv_proj(
+ hidden_size,
+ num_q_heads=num_heads,
+ num_kv_heads=num_heads,
+ head_size=head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # o_proj
+ self.dense = build_rowwise_linear(hidden_size,
+ hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ # qkv proj
+ qkv_states = self.query_key_value(hidden_states)
+ q, k, v = self.query_key_value.split_qkv(qkv_states)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.dense(attn_output)
+ return attn_output
+
+
+class EVA2CLIPMLP(nn.Module):
+ """vision MLP."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+
+ # gate up
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # activation
+ if config.hidden_act in [
+ 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python'
+ ]:
+ self.activation_fn = nn.GELU()
+ else:
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ # down
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ x = self.fc1(x)
+ x = self.activation_fn(x)
+ x = self.fc2(x)
+ return x
+
+
+class EVA2CLIPTransformerLayer(nn.Module):
+ """vision trans layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.input_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)
+ self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, hidden_states):
+ """forward."""
+ attention_input = hidden_states
+ attention_output = self.input_layernorm(
+ self.attention(attention_input))
+ hidden_states = attention_input + attention_output
+ mlp_input = hidden_states
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
+ output = mlp_input + mlp_output
+ return output
+
+
+class EVA2CLIPTransformer(nn.Module):
+ """vision transformer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ EVA2CLIPTransformerLayer(config, dtype=dtype, device=device)
+ for _ in range(config.num_hidden_layers)
+ ])
+
+ def forward(self, hidden_states):
+ """forward."""
+ for layer_module in self.layers:
+ hidden_states = layer_module(hidden_states)
+ return hidden_states
+
+
+class GLU(nn.Module):
+ """GLU."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ in_features: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.linear_proj = nn.Linear(in_features,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.norm1 = nn.LayerNorm(config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.act1 = nn.GELU()
+ self.act2 = nn.functional.silu
+ self.dense_h_to_4h = nn.Linear(config.hidden_size,
+ config.ffn_hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.gate_proj = nn.Linear(config.hidden_size,
+ config.ffn_hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, x):
+ x = self.linear_proj(x)
+ x = self.act1(self.norm1(x))
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
+ x = self.dense_4h_to_h(x)
+ return x
+
+
+class EVA2CLIPModel(nn.Module):
+ """vision model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from argparse import Namespace
+ vision_config = Namespace(**config.vision_config)
+
+ self.patch_embedding = PatchEmbedding(vision_config,
+ dtype=dtype,
+ device=device)
+ self.transformer = EVA2CLIPTransformer(vision_config,
+ dtype=dtype,
+ device=device)
+ self.linear_proj = GLU(config,
+ in_features=config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
+ out_channels=config.hidden_size,
+ kernel_size=2,
+ stride=2,
+ dtype=dtype,
+ device=device)
+ self.boi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+ self.eoi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+ self.scaling_factor = vision_config.scaling_factor
+
+ def forward(self, images):
+ """forward."""
+ x = self.patch_embedding(images)
+ x = self.transformer(x)
+
+ x = x[:, 1:]
+
+ b, s, h = x.shape
+ grid_size = int(s**0.5)
+ x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
+ x = self.conv(x)
+
+ x = x.flatten(2).transpose(1, 2)
+ x = self.linear_proj(x)
+ boi = self.boi.expand(x.shape[0], -1, -1)
+ eoi = self.eoi.expand(x.shape[0], -1, -1)
+ x = torch.cat((boi, x, eoi), dim=1)
+ x = x / self.scaling_factor
+ return x
+
+
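For orientation, a hedged shape walkthrough of the EVA2CLIP stack defined above, with `B` the batch size, `S = image_size`, `p = patch_size`, `h` the vision hidden size and `H` the language-model hidden size (derived from the module definitions, not from any specific checkpoint):

# images                     : (B, in_channels, S, S)
# PatchEmbedding             : (B, (S//p)**2 + 1, h)      patch tokens + cls token
# EVA2CLIPTransformer        : (B, (S//p)**2 + 1, h)      shape preserved
# drop cls, fold to a grid   : (B, h, S//p, S//p)
# conv (kernel=2, stride=2)  : (B, H, S//p//2, S//p//2)   2x spatial downsample
# flatten + GLU projection   : (B, (S//p//2)**2, H)
# concat boi/eoi, scale      : (B, (S//p//2)**2 + 2, H)

The final token count, (S//p//2)**2 + 2, is exactly the `vision_token_num` used by `update_model_metas` and `ChatGLMInputProcessor` below.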
class ChatGLMModel(nn.Module):
def __init__(self,
@@ -442,19 +647,32 @@ def __init__(self,
dtype=dtype,
device=device)
+ self.vision = None
+ if hasattr(config, 'vision_config'):
+ self.vision = EVA2CLIPModel(config, dtype=dtype, device=device)
+
def forward(
self,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
attn_metadata: Any = None,
+ images: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
):
"""forward."""
# token embedding
if inputs_embeds is None:
+ images_features = None
+ if images is not None:
+ images_features = self.vision(images)
+ images_features = images_features.flatten(0, 1)[None]
inputs_embeds = self.embedding(input_ids)
+ if images is not None:
+ inputs_embeds.masked_scatter_(image_mask[..., None],
+ images_features)
hidden_states = inputs_embeds
@@ -477,7 +695,8 @@ def get_input_embeddings(self):
return self.embedding
-class ChatGLMForConditionalGeneration(nn.Module, CudaGraphMixin):
+class ChatGLMForConditionalGeneration(nn.Module, DeployModelMixin,
+ CudaGraphMixin):
"""rewrote model of LlamaForCausalLM."""
def __init__(self,
@@ -491,12 +710,16 @@ def __init__(self,
# build Model
self.transformer = ChatGLMModel(config, dtype=dtype, device=device)
+ self.input_processor = ChatGLMInputProcessor(self.config, dtype)
+
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ images: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
**kwargs,
):
@@ -506,6 +729,8 @@ def forward(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
return hidden_states
@@ -529,8 +754,23 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
- if context.vision_inputs is not None:
- position_ids = _get_cogvlm_position_ids(context)[0][None]
+
+ images = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ images = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ images = [data for im_data in images for data in im_data]
+ if len(images) != 0:
+ image_token_id = images[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ images = torch.stack([data.data for data in images])
+ else:
+ images = None
+ image_mask = None
# process vision embeddings
vision_embeddings = context.input_embeddings
@@ -548,9 +788,92 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ model_metas = context.model_metas
+ if not hasattr(self.config, 'vision_config'):
+ return model_metas
+
+ input_multimodals = context.input_multimodals
+ if input_multimodals is None:
+ input_imgs = [[] for _ in model_metas]
+ else:
+ input_imgs = []
+ for mm in input_multimodals:
+ if mm is None:
+ input_imgs.append([])
+ else:
+ input_imgs.append(mm.get('image', []))
+
+ config = self.config
+ image_size: int = config.vision_config['image_size']
+ patch_size: int = config.vision_config['patch_size']
+ vision_token_num = ((image_size // patch_size // 2) *
+ (image_size // patch_size // 2) + 2)
+ num_pad = vision_token_num - 3
+
+ batched_num_img_tokens = []
+ new_model_metas = []
+ for meta, imgs in zip(model_metas, input_imgs):
+ if meta is None:
+ num_img_tokens = 0
+ else:
+ num_img_tokens = meta.get('num_img_tokens', 0)
+
+ batched_num_img_tokens.append(num_img_tokens)
+
+ num_img_tokens += num_pad * len(imgs)
+ new_model_metas.append(dict(num_img_tokens=num_img_tokens))
+
+ # prepare cogvlm position_ids
+ q_seqlens = context.q_seqlens
+ position_ids = context.position_ids
+
+ if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs):
+ num_img_tokens = torch.tensor(batched_num_img_tokens,
+ device=position_ids.device)
+ position_ids -= num_img_tokens[None]
+ else:
+ batched_position_ids = position_ids[0].split(q_seqlens)
+ for pos_ids, num_img_tok, imgs in zip(batched_position_ids,
+ batched_num_img_tokens,
+ input_imgs):
+ pos_ids -= num_img_tok
+ if len(imgs) == 0:
+ continue
+
+ seq_len = pos_ids.size(0)
+ start = pos_ids[0].cpu().item()
+ new_pos_ids = []
+
+ imgs = sorted(imgs, key=lambda img: img.start)
+ for img in imgs:
+ img_pad_pos = img.start + 1 - num_img_tok
+ num_pad = img.end - img.start - 2
+ new_pos_ids += list(range(start, img_pad_pos))
+ new_pos_ids += [img_pad_pos] * num_pad
+ start = img_pad_pos + 1
+ num_img_tok += num_pad
+
+ remain = seq_len - len(new_pos_ids)
+ new_pos_ids += list(range(start, start + remain))
+
+ new_pos_ids = pos_ids.new_tensor(new_pos_ids)
+ pos_ids[:] = new_pos_ids
+
+ position_ids = torch.cat(batched_position_ids)[None]
+ context.position_ids = position_ids
+
+ return new_model_metas
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
# modify from vllm
@@ -558,7 +881,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if 'transformer.vision' in name:
+ if '.query_key_value' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
continue
+
if 'rotary_pos_emb.inv_freq' in name:
continue
if ('rotary_pos_emb.cos_cached' in name
@@ -581,3 +914,53 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class ChatGLMInputProcessor(BaseModelInputProcessor):
+ """input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ if hasattr(config, 'vision_config'):
+ vision_config = config.vision_config
+ self.image_size = vision_config['image_size']
+ self.patch_size = vision_config['patch_size']
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+ self.vision_token_num = self.num_patches // 4
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ num_pad = input_mm['image_tokens']
+ image_token_id = input_mm.get('image_token_id', 0)
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
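`update_model_metas` above collapses every image's padding tokens onto a single position id; GLM-4V shares this scheme with CogVLM, and the same logic appears again in cogvlm.py below. A worked toy example, assuming a fresh prefill with no prior history and one image span of 6 tokens (boi, 4 pads, eoi, i.e. a 2x2 toy patch grid) starting at token index 2:

# prompt tokens : t0 t1 boi p  p  p  p  eoi t2 t3
# token index   : 0  1  2   3  4  5  6  7   8  9
# position id   : 0  1  2   3  3  3  3  4   5  6
#
# The 4 pad tokens all take position img.start + 1 = 3, so the sequence loses
# num_pad - 1 = 3 positions. The model meta therefore stores num_img_tokens = 3,
# and at decode time `position_ids -= num_img_tokens` gives 10 - 3 = 7 for the
# first generated token, continuing right after the last prompt position 6.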
diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py
index 6caf10df00..8010e5cead 100644
--- a/lmdeploy/pytorch/models/cogvlm.py
+++ b/lmdeploy/pytorch/models/cogvlm.py
@@ -1,20 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import Namespace
from typing import Any, Iterable, List, Optional, Tuple
import torch
import torch.distributed as dist
+import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from lmdeploy.pytorch.distributed import get_world_rank
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
class VisionExpertAttention(nn.Module):
@@ -28,8 +35,9 @@ def __init__(self,
is_cogvlm2 = hasattr(config, 'num_multi_query_heads')
quantization_config = getattr(config, 'quantization_config', None)
num_heads = config.num_attention_heads
- num_key_value_heads = getattr(config, 'num_multi_query_heads',
- num_heads)
+ num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads)
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
self.hidden_size = hidden_size
@@ -46,7 +54,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
self.language_expert_query_key_value = build_qkv_proj(
hidden_size,
num_q_heads=num_heads,
@@ -56,7 +64,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
@@ -322,6 +330,283 @@ def forward(
return outputs
+class PatchEmbedding(nn.Module):
+ """vision embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.proj = nn.Conv2d(config.in_channels,
+ config.hidden_size,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ dtype=dtype,
+ device=device)
+ self.cls_embedding = nn.Parameter(
+ torch.empty(1, config.hidden_size, dtype=dtype, device=device))
+ self.position_embedding = nn.Embedding(config.num_positions,
+ config.hidden_size,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, images):
+ """forward."""
+ x = self.proj(images)
+ x = x.flatten(2).transpose(1, 2)
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_token, x), dim=1)
+ x += self.position_embedding.weight.unsqueeze(0)
+ return x
+
+
+class EVA2CLIPAttention(nn.Module):
+ """vision attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ hidden_size = config.hidden_size
+ num_heads = config.num_heads
+ head_dim = config.hidden_size // config.num_heads
+ self.scale = head_dim**-0.5
+
+ # packed qkv
+ self.query_key_value = build_qkv_proj(
+ hidden_size,
+ num_q_heads=num_heads,
+ num_kv_heads=num_heads,
+ head_size=head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # o_proj
+ self.dense = build_rowwise_linear(hidden_size,
+ hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ # qkv proj
+ qkv_states = self.query_key_value(hidden_states)
+ q, k, v = self.query_key_value.split_qkv(qkv_states)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.dense(attn_output)
+ return attn_output
+
+
+class EVA2CLIPMLP(nn.Module):
+ """vision MLP."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+
+ # gate up
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # activation
+ if config.hidden_act in [
+ 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python'
+ ]:
+ self.activation_fn = nn.GELU()
+ else:
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ # down
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ x = self.fc1(x)
+ x = self.activation_fn(x)
+ x = self.fc2(x)
+ return x
+
+
+class EVA2CLIPTransformerLayer(nn.Module):
+ """vision trans layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.input_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)
+ self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, hidden_states):
+ """forward."""
+ attention_input = hidden_states
+ attention_output = self.input_layernorm(
+ self.attention(attention_input))
+ hidden_states = attention_input + attention_output
+ mlp_input = hidden_states
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
+ output = mlp_input + mlp_output
+ return output
+
+
+class EVA2CLIPTransformer(nn.Module):
+ """vision transformer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ EVA2CLIPTransformerLayer(config, dtype=dtype, device=device)
+ for _ in range(config.num_hidden_layers)
+ ])
+
+ def forward(self, hidden_states):
+ """forward."""
+ for layer_module in self.layers:
+ hidden_states = layer_module(hidden_states)
+ return hidden_states
+
+
+class GLU(nn.Module):
+ """GLU."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ in_features: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.linear_proj = nn.Linear(in_features,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.norm1 = nn.LayerNorm(config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.act1 = nn.GELU()
+ self.act2 = nn.functional.silu
+ self.dense_h_to_4h = nn.Linear(config.hidden_size,
+ config.intermediate_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.gate_proj = nn.Linear(config.hidden_size,
+ config.intermediate_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, x):
+ x = self.linear_proj(x)
+ x = self.act1(self.norm1(x))
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
+ x = self.dense_4h_to_h(x)
+ return x
+
+
+class EVA2CLIPModel(nn.Module):
+ """vision model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ vision_config = Namespace(**config.vision_config)
+
+ self.patch_embedding = PatchEmbedding(vision_config,
+ dtype=dtype,
+ device=device)
+ self.transformer = EVA2CLIPTransformer(vision_config,
+ dtype=dtype,
+ device=device)
+ self.linear_proj = GLU(config,
+ in_features=vision_config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
+ out_channels=vision_config.hidden_size,
+ kernel_size=2,
+ stride=2,
+ dtype=dtype,
+ device=device)
+ self.boi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+ self.eoi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+
+ def forward(self, images):
+ """forward."""
+ x = self.patch_embedding(images)
+ x = self.transformer(x)
+
+ x = x[:, 1:]
+
+ b, s, h = x.shape
+ grid_size = int(s**0.5)
+ x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
+ x = self.conv(x)
+
+ x = x.flatten(2).transpose(1, 2)
+ x = self.linear_proj(x)
+ boi = self.boi.expand(x.shape[0], -1, -1)
+ eoi = self.eoi.expand(x.shape[0], -1, -1)
+ x = torch.cat((boi, x, eoi), dim=1)
+ return x
+
+
class CogVLMModel(nn.Module):
"""model."""
@@ -332,7 +617,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -349,10 +633,12 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
+ # vision model
+ self.vision = EVA2CLIPModel(config, dtype=dtype, device=device)
+
# build rotary embedding
emb_type = RopeType.LinearScaling
rope_dim = config.hidden_size // config.num_attention_heads
@@ -371,6 +657,7 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
attn_metadata: Any = None,
+ images: torch.Tensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
lang_ids: torch.LongTensor = None,
vision_ids: torch.LongTensor = None,
@@ -379,7 +666,12 @@ def forward(
# token embedding
if inputs_embeds is None:
+ if images is not None:
+ images_features = self.vision(images)
+
inputs_embeds = self.embed_tokens(input_ids)
+ if vision_ids is not None:
+ inputs_embeds[0, vision_ids] = images_features.flatten(0, 1)
hidden_states = inputs_embeds
@@ -416,85 +708,7 @@ def get_input_embeddings(self):
VISION_TOKEN_TYPE = 1
-def get_vision_expert_mask(token_type_ids: torch.LongTensor):
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1]
- == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:]
- == VISION_TOKEN_TYPE)
- language_token_mask = ~vision_token_mask
- return vision_token_mask, language_token_mask
-
-
-def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor:
- tmp = x.clone()
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
- tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
- tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
- # final position ids
- y = torch.zeros_like(x, dtype=torch.long)
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
- (tmp[:, 1:] == VISION_TOKEN_TYPE) &
- (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
- y = y.cumsum(dim=-1)
- return y
-
-
-def _get_cogvlm_position_ids(context):
- """get cogvlm position_ids."""
- q_seqlens = context.q_seqlens
- history_lengths = context.kv_seqlens - q_seqlens
- vision_input_info = context.vision_inputs
- position_id_offsets = (vision_input_info.history_image_token_lengths -
- vision_input_info.history_image_nums * 3)
- lang_ids = None
- vis_ids = None
- if context.is_decoding:
- position_ids = history_lengths - position_id_offsets
- else:
- if vision_input_info.input_embeddings is not None and len(
- vision_input_info.input_embeddings) > 0:
- starts = history_lengths - vision_input_info.history_lengths
- ends = starts + q_seqlens
- token_type_ids = vision_input_info.input_embedding_indexing.to(
- torch.int)
- history_position_lengths = (vision_input_info.history_lengths -
- position_id_offsets)
- position_ids_all = (history_position_lengths[:, None] +
- build_position_ids(token_type_ids))
- position_ids = torch.cat([
- pids[s:e]
- for (pids, s, e) in zip(position_ids_all, starts, ends)
- ])
- vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids)
- vision_token_mask = torch.cat([
- masks[s:e]
- for (masks, s, e) in zip(vision_token_mask_all, starts, ends)
- ])
- mask_indexing = torch.arange(vision_token_mask.shape[-1],
- device=vision_token_mask.device)
- vis_ids = mask_indexing[vision_token_mask]
- lang_ids = mask_indexing[~vision_token_mask]
-
- else:
- position_ids = context.attention_mask.long().cumsum(-1) - 1
- position_ids += (history_lengths -
- position_id_offsets).unsqueeze(-1)
- device = position_ids.device
- position_ids_1d = [
- ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu())
- ]
- position_ids = torch.cat(position_ids_1d).to(device)
-
- return position_ids, lang_ids, vis_ids
-
-
-class CogVLMForCausalLM(nn.Module, CudaGraphMixin):
+class CogVLMForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin):
"""ModelForCausalLM."""
packed_modules_mapping = {
@@ -512,6 +726,8 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+ # preprocessor
+ self.input_processor = CogVLMInputProcessor(self.config, dtype)
# build model
self.model = CogVLMModel(config, dtype=dtype, device=device)
# build lm_head
@@ -527,6 +743,7 @@ def forward(
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ images: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
lang_ids: torch.LongTensor = None,
vision_ids: torch.LongTensor = None,
@@ -538,6 +755,7 @@ def forward(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
inputs_embeds=inputs_embeds,
lang_ids=lang_ids,
vision_ids=vision_ids,
@@ -561,8 +779,36 @@ def prepare_inputs_for_generation(
"""prepare input."""
# get input_ids, position_ids and attention metadatas
input_ids = context.input_ids
- position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context)
- position_ids = position_ids[None]
+
+ # position ids come from the step context; update_model_metas rewrites them to the CogVLM scheme
+ position_ids = context.position_ids
+ lang_ids = None
+ vis_ids = None
+
+ # vision inputs
+ images = None
+ if context.input_multimodals is not None:
+ images = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ images = [data for im_data in images for data in im_data]
+ if len(images) == 0:
+ images = None
+
+ if images is not None:
+ image_token_id = images[0].meta['image_token_id']
+ vis_mask = input_ids[0] == image_token_id
+ images = torch.stack([data.data for data in images])
+
+ # get lang_ids
+ vis_range = torch.arange(0,
+ input_ids.size(-1),
+ device=input_ids.device)
+ vis_ids = vis_range[vis_mask]
+ lang_ids = vis_range[~vis_mask]
+
attn_metadata = context.attn_metadata
# process vision embeddings
@@ -581,6 +827,7 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
inputs_embeds=inputs_embeds,
lang_ids=lang_ids,
vision_ids=vis_ids,
@@ -597,8 +844,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
- if 'model.vision' in name:
- continue
if 'rotary_emb.inv_freq' in name:
continue
if ('rotary_emb.cos_cached' in name
@@ -607,6 +852,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if '.vision.' in name:
+ continue
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -620,6 +867,136 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
load_weight(param, q, shard_id='q')
load_weight(param, k, shard_id='k')
load_weight(param, v, shard_id='v')
+ elif '.query_key_value' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
else:
param = params_dict[name]
load_weight(param, loaded_weight)
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ model_metas = context.model_metas
+ input_multimodals = context.input_multimodals
+ if input_multimodals is None:
+ input_imgs = [[] for _ in model_metas]
+ else:
+ input_imgs = []
+ for mm in input_multimodals:
+ if mm is None:
+ input_imgs.append([])
+ else:
+ input_imgs.append(mm.get('image', []))
+
+ config = self.config
+ image_size: int = config.vision_config['image_size']
+ patch_size: int = config.vision_config['patch_size']
+ vision_token_num = ((image_size // patch_size // 2) *
+ (image_size // patch_size // 2) + 2)
+ num_pad = vision_token_num - 3
+
+ batched_num_img_tokens = []
+ new_model_metas = []
+ for meta, imgs in zip(model_metas, input_imgs):
+ if meta is None:
+ num_img_tokens = 0
+ else:
+ num_img_tokens = meta.get('num_img_tokens', 0)
+
+ batched_num_img_tokens.append(num_img_tokens)
+
+ num_img_tokens += num_pad * len(imgs)
+ new_model_metas.append(dict(num_img_tokens=num_img_tokens))
+
+ # prepare cogvlm position_ids
+ q_seqlens = context.q_seqlens
+ position_ids = context.position_ids
+
+ if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs):
+ num_img_tokens = torch.tensor(batched_num_img_tokens,
+ device=position_ids.device)
+ position_ids -= num_img_tokens[None]
+ else:
+ batched_position_ids = position_ids[0].split(q_seqlens)
+ for pos_ids, num_img_tok, imgs in zip(batched_position_ids,
+ batched_num_img_tokens,
+ input_imgs):
+ pos_ids -= num_img_tok
+ if len(imgs) == 0:
+ continue
+
+ seq_len = pos_ids.size(0)
+ start = pos_ids[0].cpu().item()
+ new_pos_ids = []
+
+ imgs = sorted(imgs, key=lambda img: img.start)
+ for img in imgs:
+ img_pad_pos = img.start + 1 - num_img_tok
+ num_pad = img.end - img.start - 2
+ new_pos_ids += list(range(start, img_pad_pos))
+ new_pos_ids += [img_pad_pos] * num_pad
+ start = img_pad_pos + 1
+ num_img_tok += num_pad
+
+ remain = seq_len - len(new_pos_ids)
+ new_pos_ids += list(range(start, start + remain))
+
+ new_pos_ids = pos_ids.new_tensor(new_pos_ids)
+ pos_ids[:] = new_pos_ids
+
+ position_ids = torch.cat(batched_position_ids)[None]
+ context.position_ids = position_ids
+
+ return new_model_metas
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class CogVLMInputProcessor(BaseModelInputProcessor):
+ """input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+ image_size: int = config.vision_config['image_size']
+ patch_size: int = config.vision_config['patch_size']
+ self.vision_token_num = ((image_size // patch_size // 2) *
+ (image_size // patch_size // 2) + 2)
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals=None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
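The `update_model_metas` hook above derives the number of image placeholder tokens from the vision config and then compacts the position ids so that all pad tokens of one image share a single position. A minimal sketch of that re-mapping; the `image_size`/`patch_size` values are illustrative assumptions, not taken from this diff:

```python
# Minimal sketch of the CogVLM position-id compaction above.
# image_size/patch_size are illustrative assumptions, not taken from this diff.
image_size, patch_size = 490, 14
vision_token_num = (image_size // patch_size // 2) ** 2 + 2   # 17*17 + 2 = 291
num_pad_per_image = vision_token_num - 3                      # 288


def cogvlm_position_ids(seq_len, images, start=0):
    """images: sorted (img_start, img_end) token spans. BOI/EOI keep normal
    positions; the pad tokens in between collapse onto one shared position."""
    new_pos_ids, num_img_tok = [], 0
    for img_start, img_end in images:
        img_pad_pos = img_start + 1 - num_img_tok
        num_pad = img_end - img_start - 2
        new_pos_ids += list(range(start, img_pad_pos))
        new_pos_ids += [img_pad_pos] * num_pad
        start = img_pad_pos + 1
        num_img_tok += num_pad
    new_pos_ids += list(range(start, start + seq_len - len(new_pos_ids)))
    return new_pos_ids


# 12-token prompt, one image in slots [2, 8): <t> <t> <BOI> p p p p <EOI> <t> ...
print(cogvlm_position_ids(12, [(2, 8)]))
# [0, 1, 2, 3, 3, 3, 3, 4, 5, 6, 7, 8]
```

Collapsing the pads keeps the text positions contiguous, which appears to be what the decoding branch relies on when it simply subtracts the accumulated `num_img_tokens`.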
diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py
index e71ff17fe9..7e61fd317d 100644
--- a/lmdeploy/pytorch/models/dbrx.py
+++ b/lmdeploy/pytorch/models/dbrx.py
@@ -9,7 +9,7 @@
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm,
RopeType, build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -165,7 +165,7 @@ def __init__(self,
act_fn_name = ffn_act_fn.get('name', None)
assert act_fn_name == 'silu'
- self.mlp = FusedMoE(
+ self.mlp = build_fused_moe(
hidden_size,
ffn_hidden_size,
moe_num_experts,
@@ -522,7 +522,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if '.experts' in name:
loaded_weight = loaded_weight.unflatten(0, (num_experts, -1))
if '.w1' in name:
- name = name.replace('.w1', '.gate_up_weights')
+ name = name.replace('.w1', '.gate_up.weight')
param = params_dict[name]
for exp_id in range(num_experts):
weight = loaded_weight[exp_id]
@@ -531,7 +531,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_id=exp_id,
shard_id='gate')
elif '.v1' in name:
- name = name.replace('.v1', '.gate_up_weights')
+ name = name.replace('.v1', '.gate_up.weight')
param = params_dict[name]
for exp_id in range(num_experts):
weight = loaded_weight[exp_id]
@@ -540,7 +540,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_id=exp_id,
shard_id='up')
elif '.w2' in name:
- name = name.replace('.w2', '.down_weights')
+ name = name.replace('.w2', '.down.weight')
param = params_dict[name]
for exp_id in range(num_experts):
weight = loaded_weight[exp_id].t()
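The renames in this hunk (`.w1`/`.v1`/`.w2` to `.gate_up.weight`/`.down.weight`) follow the switch to `build_fused_moe`, whose expert weights now live under `gate_up`/`down` sub-modules. A small sketch of the checkpoint-key rewrite the loader performs; the example key is hypothetical:

```python
# Sketch of the expert-weight key rewrite in DBRX load_weights above.
REMAP = {
    '.w1': ('.gate_up.weight', 'gate'),
    '.v1': ('.gate_up.weight', 'up'),
    '.w2': ('.down.weight', 'down'),
}


def remap_expert_key(name: str):
    """Map a DBRX expert key onto the fused-MoE parameter name and shard id."""
    for old, (new, shard_id) in REMAP.items():
        if old in name:
            return name.replace(old, new), shard_id
    return name, None


# hypothetical checkpoint key, for illustration only
print(remap_expert_key('transformer.blocks.0.ffn.experts.mlp.w1'))
# ('transformer.blocks.0.ffn.experts.mlp.gate_up.weight', 'gate')
```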
diff --git a/lmdeploy/pytorch/models/deepseek.py b/lmdeploy/pytorch/models/deepseek.py
index 5742baeee5..09c0b74fcc 100644
--- a/lmdeploy/pytorch/models/deepseek.py
+++ b/lmdeploy/pytorch/models/deepseek.py
@@ -12,7 +12,7 @@
SiluAndMul, build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -135,7 +135,7 @@ def __init__(self,
self.softmax_topk = SoftmaxTopK(self.top_k)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -265,12 +265,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -315,7 +313,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -332,7 +329,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -528,14 +524,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.n_routed_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.gate_proj.weight', exp_id,
- 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.up_proj.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.down_proj.weight', exp_id,
- 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj',
+ exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj',
+ exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj',
+ exp_id, 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
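The expert mapping now keys on module prefixes (`.experts.gate_up`, `.experts.down`) instead of full parameter names, presumably so that any suffix under the same expert projection (`.weight` today, quantization scales alike) resolves through a single entry. A sketch of how the tuples are built and matched; the checkpoint key is hypothetical:

```python
# Sketch of building and matching expert_params_mapping, mirroring the loop above.
def build_expert_mapping(num_experts: int):
    mapping = []
    for exp_id in range(num_experts):
        mapping += [
            ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate'),
            ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up'),
            ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down'),
        ]
    return mapping


def match(name: str, mapping):
    """Return (fused_param_name, expert_id, shard_id) for a checkpoint key."""
    for param_prefix, weight_prefix, exp_id, shard_id in mapping:
        if weight_prefix in name:
            return name.replace(weight_prefix, param_prefix), exp_id, shard_id
    return name, None, None


mapping = build_expert_mapping(num_experts=4)
print(match('model.layers.1.mlp.experts.2.up_proj.weight', mapping))
# ('model.layers.1.mlp.experts.gate_up.weight', 2, 'up')
```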
diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py
index 34debae229..b69ae6650d 100644
--- a/lmdeploy/pytorch/models/deepseek_v2.py
+++ b/lmdeploy/pytorch/models/deepseek_v2.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
+import torch.nn.functional as F
from torch import nn
from lmdeploy.pytorch.distributed import get_world_rank
@@ -13,7 +14,7 @@
from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
build_merged_colwise_linear,
build_rowwise_linear)
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import YarnParameters
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
@@ -81,7 +82,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = None
+ quantization_config = getattr(config, 'quantization_config', None)
self.q_lora_rank = config.q_lora_rank
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
@@ -90,6 +91,9 @@ def __init__(self,
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
+ num_key_value_heads = getattr(config, 'num_key_value_heads', 1)
if self.q_lora_rank is None:
self.q_proj = build_colwise_linear(
@@ -99,6 +103,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=True,
+ quant_config=quantization_config,
)
else:
self.q_a_proj = build_colwise_linear(
@@ -108,6 +113,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=False,
+ quant_config=quantization_config,
)
self.q_a_layernorm = RMSNorm(config.q_lora_rank,
1e-6,
@@ -121,6 +127,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=True,
+ quant_config=quantization_config,
)
self.kv_a_proj_with_mqa = build_colwise_linear(
@@ -130,6 +137,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=False,
+ quant_config=quantization_config,
)
self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,
1e-6,
@@ -157,10 +165,9 @@ def __init__(self,
self.num_heads,
config.kv_lora_rank + self.qk_rope_head_dim,
scale=self.softmax_scale,
- num_kv_heads=1,
+ num_kv_heads=num_key_value_heads,
v_head_size=config.kv_lora_rank,
- replicate_kv=True,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
self.vc = DeepseekV2BMM(self.num_heads,
config.kv_lora_rank,
@@ -174,6 +181,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=True,
+ quant_config=quantization_config,
)
def _q_proj(self, hidden_states, num_heads: int, nope_size: int,
@@ -270,6 +278,104 @@ def forward(
return attn_output
+class MoEGate(nn.Module):
+ """Deepseek Gate."""
+
+ def __init__(self,
+ config: Any,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.scoring_func = config.scoring_func
+ self.alpha = config.aux_loss_alpha
+ self.seq_aux = config.seq_aux
+ self.topk_method = config.topk_method
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+ self.renormalize = self.top_k > 1 and self.norm_topk_prob
+
+ # topk selection algorithm
+ self.norm_topk_prob = config.norm_topk_prob
+ self.gating_dim = config.hidden_size
+ self.weight = nn.Parameter(
+ torch.empty((self.n_routed_experts, self.gating_dim),
+ dtype=dtype,
+ device=device))
+ if self.topk_method == 'noaux_tc':
+ self.e_score_correction_bias = nn.Parameter(
+ torch.empty((self.n_routed_experts, ),
+ dtype=dtype,
+ device=device))
+ self.softmax_topk = SoftmaxTopK(self.top_k)
+
+ def _compute_scores(self, logits: torch.Tensor):
+ """compute scores."""
+ if self.scoring_func == 'softmax':
+ scores = logits.softmax(dim=-1, dtype=torch.float32)
+ elif self.scoring_func == 'sigmoid':
+ scores = logits.sigmoid()
+ else:
+            raise NotImplementedError('unsupported scoring function '
+ f'for MoE gating: {self.scoring_func}')
+ return scores
+
+ def forward(self, hidden_states: torch.Tensor):
+ """forward."""
+ sequence_length, hidden_dim = hidden_states.shape
+ router_logits = F.linear(hidden_states, self.weight)
+
+ if self.topk_method == 'greedy':
+ topk_weight, topk_idx = self.softmax_topk(router_logits)
+ elif self.topk_method == 'group_limited_greedy':
+ scores = self._compute_scores(router_logits)
+ grouped_logits = scores.unflatten(-1, (self.n_group, -1))
+ group_scores = (grouped_logits.max(-1).values)
+ group_idx = torch.topk(group_scores,
+ k=self.topk_group,
+ dim=-1,
+ sorted=False)[1] # [n, top_k_group]
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
+ group_mask = ~group_mask.bool()[..., None]
+ grouped_logits = grouped_logits.masked_fill(group_mask, 0.0)
+ scores = grouped_logits.flatten(1, 2)
+ topk_weight, topk_idx = self.softmax_topk(scores)
+ elif self.topk_method == 'noaux_tc':
+ scores = self._compute_scores(router_logits)
+ scores_for_choice = scores.view(
+ sequence_length, -1) + self.e_score_correction_bias[None]
+ group_scores = (scores_for_choice.view(
+ sequence_length, self.n_group,
+ -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group]
+ group_idx = torch.topk(group_scores,
+ k=self.topk_group,
+ dim=-1,
+ sorted=False)[1] # [n, top_k_group]
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
+ score_mask = (group_mask.unsqueeze(-1).expand(
+ sequence_length, self.n_group,
+ self.n_routed_experts // self.n_group).reshape(
+ sequence_length, -1)) # [n, e]
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
+ 0.0) # [n, e]
+ _, topk_idx = torch.topk(tmp_scores,
+ k=self.top_k,
+ dim=-1,
+ sorted=False)
+ topk_weight = scores.gather(1, topk_idx)
+ else:
+ raise RuntimeError(f'Unsupported topk_method: {self.topk_method}')
+ if not self.renormalize:
+ topk_weight = topk_weight * self.routed_scaling_factor
+ return topk_weight, topk_idx
+
+
class DeepseekV2MoE(nn.Module):
"""Deepseek v2 MoE."""
@@ -278,6 +384,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
self.hidden_dim = config.hidden_size
self.ffn_dim = config.moe_intermediate_size
self.num_experts = config.n_routed_experts
@@ -289,18 +396,9 @@ def __init__(self,
self.n_group = config.n_group
self.topk_group = config.topk_group
- self.gate = build_rowwise_linear(
- self.hidden_dim,
- self.num_experts,
- bias=False,
- dtype=dtype,
- device=device,
- is_tp=False,
- )
-
- self.softmax_topk = SoftmaxTopK(self.top_k)
+ self.gate = MoEGate(config, dtype=dtype, device=device)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -309,6 +407,7 @@ def __init__(self,
dtype=dtype,
device=device,
all_reduce=False,
+ quant_config=quantization_config,
)
self.shared_experts = None
@@ -333,27 +432,8 @@ def forward(self, hidden_states: torch.Tensor):
"""forward."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
- router_logits = self.gate(hidden_states)
+ topk_weights, topk_ids = self.gate(hidden_states)
- if self.topk_method == 'greedy':
- topk_weights, topk_ids = self.softmax_topk(router_logits)
- elif self.topk_method == 'group_limited_greedy':
- grouped_logits = router_logits.unflatten(-1, (self.n_group, -1))
- group_scores = (grouped_logits.max(-1).values)
- group_idx = torch.topk(group_scores,
- k=self.topk_group,
- dim=-1,
- sorted=False)[1] # [n, top_k_group]
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
- group_mask = ~group_mask.bool()[..., None]
- grouped_logits = grouped_logits.masked_fill(group_mask, 0.0)
- router_logits = grouped_logits.flatten(1, 2)
- topk_weights, topk_ids = self.softmax_topk(router_logits)
- else:
- raise RuntimeError(f'Unsupported topk_method: {self.topk_method}')
- if not self.renormalize:
- topk_weights = topk_weights * self.routed_scaling_factor
out_states = self.experts(
hidden_states,
topk_weights,
@@ -450,12 +530,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -572,7 +650,6 @@ def forward(
cos, sin = cos[0], sin[0]
rotary_pos_emb = (cos, sin)
for idx, decoder_layer in enumerate(self.layers):
-
past_key_value = past_key_values[idx]
hidden_states, residual = decoder_layer(
hidden_states,
@@ -601,6 +678,8 @@ def __init__(self,
device: torch.device = None):
super().__init__()
self.config = config
+ self.quantization_config = getattr(config, 'quantization_config', None)
+ self.dtype = dtype
self.ctx_mgr = ctx_mgr
self.model = DeepseekV2Model(config, dtype=dtype, device=device)
# build lm_head
@@ -609,6 +688,7 @@ def __init__(self,
bias=False,
dtype=dtype,
device=device)
+ self._load_buffers = dict()
def forward(
self,
@@ -692,40 +772,99 @@ def __update_pe(weight, head_dim: int, pe_dim_offset: int):
weight = weight.flatten(0, 1)
return weight
+ def __load_kcvc(name: str, weight: torch.Tensor):
+ """load kc and vc from weight."""
+ config = self.config
+ v_head_dim = config.v_head_dim
+ qk_nope_head_dim = config.qk_nope_head_dim
+ w_kc, w_vc = weight.unflatten(
+ 0, (-1, qk_nope_head_dim + v_head_dim)).split(
+ [qk_nope_head_dim, v_head_dim], dim=1)
+ w_vc = w_vc.transpose(1, 2).contiguous()
+ kc_param_name = name.replace('.kv_b_proj', '.kc')
+ param_kc = params_dict[kc_param_name]
+ load_weight(param_kc, w_kc)
+ vc_param_name = name.replace('.kv_b_proj', '.vc')
+ param_vc = params_dict[vc_param_name]
+ load_weight(param_vc, w_vc)
+
+ def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor,
+ dtype: torch.dtype):
+ """dequant weight."""
+ dim_w0, dim_w1 = weight.shape
+ dim_s0, dim_s1 = scale.shape
+ assert dim_w0 % dim_s0 == 0
+ assert dim_w1 % dim_s1 == 0
+ group0 = dim_w0 // dim_s0
+ group1 = dim_w1 // dim_s1
+ weight = weight.reshape(dim_s0, group0, dim_s1, group1)
+ scale = scale.reshape(dim_s0, 1, dim_s1, 1)
+ weight = weight.to(scale.dtype) * scale
+ weight = weight.to(dtype)
+ weight = weight.reshape(dim_w0, dim_w1)
+ return weight
+
+ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
+ """dequant weight."""
+ if name.endswith('.weight'):
+ weight_name = name
+ scale_name = name.replace('.weight', '.scale')
+ elif name.endswith('.scale'):
+ weight_name = name.replace('.scale', '.weight')
+ scale_name = name
+ self._load_buffers[name] = loaded_weight
+ if (weight_name in self._load_buffers
+ and scale_name in self._load_buffers):
+ weight = self._load_buffers.pop(weight_name)
+ scale = self._load_buffers.pop(scale_name)
+ kc_param_name = weight_name.replace('.kv_b_proj', '.kc')
+ dtype = params_dict[kc_param_name].dtype
+ weight = __dequant_weight(weight, scale, dtype)
+ __load_kcvc(weight_name, weight)
+
for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping:
if mod_name not in name:
continue
- weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)
+ if name.endswith('.scale'):
+ weight = loaded_weight
+ else:
+ weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)
param = params_dict[name]
load_weight(param, weight)
break
else:
if '.kv_b_proj' in name:
- config = self.config
- v_head_dim = config.v_head_dim
- qk_nope_head_dim = config.qk_nope_head_dim
- w_kc, w_vc = loaded_weight.unflatten(
- 0, (-1, qk_nope_head_dim + v_head_dim)).split(
- [qk_nope_head_dim, v_head_dim], dim=1)
- w_vc = w_vc.transpose(1, 2).contiguous()
- kc_param_name = name.replace('.kv_b_proj', '.kc')
- param_kc = params_dict[kc_param_name]
- load_weight(param_kc, w_kc)
- vc_param_name = name.replace('.kv_b_proj', '.vc')
- param_vc = params_dict[vc_param_name]
- load_weight(param_vc, w_vc)
+ quantization_config = self.quantization_config
+ quant_method = None
+ if quantization_config is not None:
+ quant_method = quantization_config.get('quant_method')
+
+ if quant_method == 'fp8':
+ # update blocked fp8 weight
+ __load_kcvc_blocked_fp8(name, loaded_weight)
+ else:
+ __load_kcvc(name, loaded_weight)
else:
param = params_dict[name]
load_weight(param, loaded_weight)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
+
+ def __skip_nextn(name, nextn_keys):
+ for nextn_key in nextn_keys:
+ if nextn_key in name:
+ return True
+ return False
+
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
('.gate_up_proj', '.gate_proj', 0),
('.gate_up_proj', '.up_proj', 1),
]
+ scale_suffix = '.weight_scale_inv'
+
config = self.config
qk_rope_head_dim = config.qk_rope_head_dim
kv_lora_rank = config.kv_lora_rank
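`__dequant_weight` above expands a per-block scale over the quantized `kv_b_proj` weight so the result can be split into the `kc`/`vc` projections. A standalone sketch of that block-wise dequantization; the 128x128 block size is an illustrative assumption:

```python
import torch


def dequant_blocked(weight: torch.Tensor, scale: torch.Tensor,
                    dtype=torch.bfloat16) -> torch.Tensor:
    """Multiply each (group0 x group1) block of `weight` by its scale entry,
    mirroring __dequant_weight above."""
    dim_w0, dim_w1 = weight.shape
    dim_s0, dim_s1 = scale.shape
    assert dim_w0 % dim_s0 == 0 and dim_w1 % dim_s1 == 0
    g0, g1 = dim_w0 // dim_s0, dim_w1 // dim_s1
    w = weight.reshape(dim_s0, g0, dim_s1, g1).to(scale.dtype)
    w = w * scale.reshape(dim_s0, 1, dim_s1, 1)
    return w.to(dtype).reshape(dim_w0, dim_w1)


# e.g. a 256x256 weight with 128x128 quantization blocks
w = torch.randint(-8, 8, (256, 256), dtype=torch.int8)   # stand-in for fp8 data
s = torch.rand(2, 2, dtype=torch.float32)
print(dequant_blocked(w, s).shape)  # torch.Size([256, 256])
```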
@@ -739,16 +878,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.n_routed_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.gate_proj.weight', exp_id,
- 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.up_proj.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.down_proj.weight', exp_id,
- 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj',
+ exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj',
+ exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj',
+ exp_id, 'down')
expert_params_mapping += [gate_param, up_param, down_param]
+ num_hidden_layers = self.config.num_hidden_layers
+
+ num_nextn_predict_layers = getattr(self.config,
+ 'num_nextn_predict_layers', 1)
+ nextn_keys = [
+ f'.layers.{num_hidden_layers+i}'
+ for i in range(num_nextn_predict_layers)
+ ]
+
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if 'rotary_emb.inv_freq' in name:
@@ -756,8 +902,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if ('rotary_emb.cos_cached' in name
or 'rotary_emb.sin_cached' in name):
continue
+ if '.layers' in name:
+ # skip nextn
+ if __skip_nextn(name, nextn_keys):
+ continue
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
continue
+ if name.endswith(scale_suffix):
+ name = name[:-len(scale_suffix)] + '.scale'
if '.experts' in name:
self._load_weight_experts(
name,
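`load_weights` also skips the trailing "next-n prediction" layers that some DeepSeek checkpoints append after the regular decoder stack, and normalizes the blocked-fp8 scale suffix before dispatching. A sketch of that filtering; the layer count is illustrative:

```python
# Sketch of the nextn-layer filtering and scale-suffix rename in load_weights.
SCALE_SUFFIX = '.weight_scale_inv'


def filter_and_rename(names, num_hidden_layers=61, num_nextn_predict_layers=1):
    nextn_keys = [f'.layers.{num_hidden_layers + i}'
                  for i in range(num_nextn_predict_layers)]
    for name in names:
        if '.layers' in name and any(k in name for k in nextn_keys):
            continue                      # next-n prediction layers are not loaded
        if name.endswith(SCALE_SUFFIX):
            name = name[:-len(SCALE_SUFFIX)] + '.scale'
        yield name


names = ['model.layers.60.mlp.experts.0.up_proj.weight_scale_inv',
         'model.layers.61.embed_tokens.weight']   # 61 == first next-n layer
print(list(filter_and_rename(names)))
# ['model.layers.60.mlp.experts.0.up_proj.scale']
```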
diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py
index 8f8659dc5e..2d2edb9f49 100644
--- a/lmdeploy/pytorch/models/falcon.py
+++ b/lmdeploy/pytorch/models/falcon.py
@@ -31,34 +31,31 @@ def __init__(self,
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
- self.num_kv_heads = self.num_attention_heads
+ self.num_kv_heads = getattr(config, 'num_kv_heads',
+ config.num_attention_heads)
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
self.head_size = (self.hidden_size // config.num_attention_heads)
- self.multi_query_attention = config.multi_query
- if self.multi_query_attention:
- self.num_kv_heads = 1
self.query_key_value = build_qkv_proj(
config.hidden_size,
num_q_heads=self.num_attention_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
bias=config.bias,
- replicate_kv=self.multi_query_attention,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# apply rotary
self.apply_rotary_pos_emb = ApplyRotaryEmb()
self.rotary = config.rotary
# attention
- self.attn_fwd = Attention(
- self.num_attention_heads,
- self.head_size,
- num_kv_heads=self.num_kv_heads,
- alibi=config.alibi,
- )
+ self.attn_fwd = Attention(self.num_attention_heads,
+ self.head_size,
+ num_kv_heads=self.num_kv_heads,
+ alibi=config.alibi)
# o_proj
self.dense = build_rowwise_linear(self.hidden_size,
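Replacing the explicit `multi_query`/`replicate_kv` flags with a `num_replicate_key_value_heads` hint lets a multi-query checkpoint (a single KV head) still be sharded under tensor parallelism by replicating that head. The toy function below only illustrates the head-count arithmetic behind the idea; the actual sharding lives inside `build_qkv_proj`, not here:

```python
# Toy illustration of why a multi-query model needs replicated KV heads under
# tensor parallelism; the real logic is in build_qkv_proj.
def kv_heads_per_rank(num_kv_heads: int, tp: int) -> tuple:
    """Return (heads_per_rank, num_replicate) for a given TP size."""
    if num_kv_heads >= tp:
        return num_kv_heads // tp, 1           # normal GQA sharding
    # fewer KV heads than ranks: replicate each head tp // num_kv_heads times
    return 1, tp // num_kv_heads


print(kv_heads_per_rank(num_kv_heads=1, tp=4))   # (1, 4): MQA, replicate x4
print(kv_heads_per_rank(num_kv_heads=8, tp=2))   # (4, 1): plain sharding
```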
diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py
index ca36f15651..1f24206b16 100644
--- a/lmdeploy/pytorch/models/gemma.py
+++ b/lmdeploy/pytorch/models/gemma.py
@@ -31,7 +31,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = config.head_dim
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -42,7 +43,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
@@ -262,7 +263,6 @@ def __init__(self,
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -279,7 +279,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py
index 99c622e4ac..fdee716b4a 100644
--- a/lmdeploy/pytorch/models/internlm.py
+++ b/lmdeploy/pytorch/models/internlm.py
@@ -28,7 +28,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -39,7 +40,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py
index 6cbc2ccff3..52f51a3ad1 100644
--- a/lmdeploy/pytorch/models/internlm2.py
+++ b/lmdeploy/pytorch/models/internlm2.py
@@ -28,7 +28,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = hidden_size // num_heads
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.wqkv = build_qkv_proj(
hidden_size,
@@ -39,6 +40,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
+ num_replicate_kv_heads=num_replicate_kv_heads,
)
# rotary embedding
@@ -219,7 +221,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.tok_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -239,7 +240,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -395,6 +395,32 @@ def prepare_inputs_for_generation(
inputs_embeds=inputs_embeds,
)
+ def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+ adapter_id: int):
+ """load lora weights."""
+
+ from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+ num_heads = self.config.num_attention_heads
+ num_key_value_heads = self.config.num_key_value_heads
+ hidden_size = self.config.hidden_size
+ head_dim = hidden_size // num_heads
+ group_size = num_heads // num_key_value_heads
+
+ def _rearange_wqkv(weights):
+ for name, loaded_weight in weights:
+ if 'wqkv.lora_B' in name:
+ loaded_weight = loaded_weight.unflatten(
+ 0, (-1, 2 + group_size, head_dim))
+ q = loaded_weight[:, :-2].flatten(0, 2)
+ k = loaded_weight[:, -2].flatten(0, 1)
+ v = loaded_weight[:, -1].flatten(0, 1)
+ loaded_weight = torch.cat([q, k, v], dim=0)
+ yield name, loaded_weight
+
+ weights_iter = _rearange_wqkv(weights)
+ load_lora_weights(self, weights_iter, adapter_id)
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
# modify from vllm
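InternLM2 packs `wqkv` so that each KV group stores its `group_size` query heads followed by one key head and one value head; the new `load_lora_weights` hook therefore re-orders the LoRA-B matrix into plain q/k/v order before handing it to the generic adapter loader. A numeric walk-through with made-up head counts:

```python
import torch

num_heads, num_kv_heads, head_dim, rank = 8, 2, 4, 3
group_size = num_heads // num_kv_heads            # 4 query heads per kv group

# packed wqkv LoRA-B: per kv group -> [q * group_size, k, v] along dim 0
lora_b = torch.arange(
    num_kv_heads * (group_size + 2) * head_dim * rank,
    dtype=torch.float32).reshape(-1, rank)

w = lora_b.unflatten(0, (-1, 2 + group_size, head_dim))   # [kv, g+2, d, r]
q = w[:, :-2].flatten(0, 2)                               # [kv*g*d, r]
k = w[:, -2].flatten(0, 1)                                # [kv*d, r]
v = w[:, -1].flatten(0, 1)                                # [kv*d, r]
rearranged = torch.cat([q, k, v], dim=0)
print(rearranged.shape)   # torch.Size([48, 3]) == (num_heads + 2*num_kv_heads)*head_dim
```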
diff --git a/lmdeploy/pytorch/models/internlm2_ve.py b/lmdeploy/pytorch/models/internlm2_ve.py
index b1a2329597..c10faa5f5d 100644
--- a/lmdeploy/pytorch/models/internlm2_ve.py
+++ b/lmdeploy/pytorch/models/internlm2_ve.py
@@ -105,7 +105,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.tok_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -125,7 +124,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 70dd8f2159..5fccd627e5 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -1,17 +1,315 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
+import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj,
+ build_rowwise_linear)
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
-class InternVLChatModel(nn.Module, CudaGraphMixin):
+class InternVisionEmbeddings(nn.Module):
+ """intern vision embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), )
+
+ self.patch_embedding = nn.Conv2d(in_channels=3,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ dtype=dtype,
+ device=device)
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(
+ torch.empty(1,
+ self.num_positions,
+ self.embed_dim,
+ dtype=dtype,
+ device=device))
+
+ def _get_pos_embed(self, pos_embed, H, W):
+ target_dtype = pos_embed.dtype
+ pos_embed = pos_embed.float().reshape(
+ 1, self.image_size // self.patch_size,
+ self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
+ pos_embed = F.interpolate(pos_embed,
+ size=(H, W),
+ mode='bicubic',
+ align_corners=False).reshape(
+ 1, -1, H * W).permute(0, 2,
+ 1).to(target_dtype)
+ return pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+            pixel_values)  # shape = [batch, embed_dim, grid_h, grid_w]
+ batch_size, _, height, width = patch_embeds.shape
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1,
+ -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ position_embedding = torch.cat([
+ self.position_embedding[:, :1, :],
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height,
+ width)
+ ],
+ dim=1)
+ embeddings = embeddings + position_embedding.to(target_dtype)
+ return embeddings
+
+
+NORM2FN = {
+ 'rms_norm': RMSNorm,
+ 'layer_norm': LayerNorm,
+}
+
+
+class InternAttention(nn.Module):
+ """intern vl attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+
+ self.qkv = build_qkv_proj(
+ self.embed_dim,
+ num_q_heads=self.num_heads,
+ num_kv_heads=self.num_heads,
+ head_size=self.head_dim,
+ bias=config.qkv_bias,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.qk_normalization = config.qk_normalization
+
+ if self.qk_normalization:
+ self.q_norm = RMSNorm(
+ self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device,
+ tp=True,
+ align=self.head_dim,
+ )
+ self.k_norm = RMSNorm(
+ self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device,
+ tp=True,
+ align=self.head_dim,
+ )
+
+ self.scale = self.head_dim**-0.5
+
+ # o_proj
+ self.proj = build_rowwise_linear(self.embed_dim,
+ self.embed_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ tp_align_size=self.head_dim)
+
+ def forward(self, hidden_states):
+ """forward."""
+
+ # qkv proj
+ qkv_states = self.qkv(hidden_states)
+ q, k, v = self.qkv.split_qkv(qkv_states)
+
+ if self.qk_normalization:
+ q_shape = q.shape
+ q = self.q_norm(q.flatten(-2, -1)).view(q_shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(q_shape)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class InternMLP(nn.Module):
+ """intern vl mlp."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.act = ACT2FN[config.hidden_act]
+
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class InternVisionEncoderLayer(nn.Module):
+ """intern vision encoder layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.norm_type = getattr(config, 'norm_type', 'rms_norm')
+
+ self.attn = InternAttention(config, dtype=dtype, device=device)
+ self.mlp = InternMLP(config, dtype=dtype, device=device)
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ self.ls1 = nn.Parameter(
+ torch.empty(self.embed_dim, dtype=dtype, device=device))
+ self.ls2 = nn.Parameter(
+ torch.empty(self.embed_dim, dtype=dtype, device=device))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ):
+ """forward."""
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+
+ hidden_states = hidden_states + self.mlp(
+ self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
+
+ return hidden_states
+
+
+class InternVisionEncoder(nn.Module):
+ """intern vision encoder."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ InternVisionEncoderLayer(config, dtype=dtype, device=device)
+ for idx in range(config.num_hidden_layers)
+ ])
+
+ def forward(
+ self,
+ inputs_embeds,
+ ):
+ """forward."""
+ hidden_states = inputs_embeds
+ for _, encoder_layer in enumerate(self.layers):
+ layer_outputs = encoder_layer(hidden_states, )
+ hidden_states = layer_outputs
+ return hidden_states
+
+
+class InternVisionModel(nn.Module):
+ """intern vision model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = InternVisionEmbeddings(config,
+ dtype=dtype,
+ device=device)
+ self.encoder = InternVisionEncoder(config, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ """forward."""
+ assert pixel_values.dim() == 4
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(inputs_embeds=hidden_states)
+ last_hidden_state = encoder_outputs
+
+ return last_hidden_state
+
+
+class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin):
def __init__(self,
config: PretrainedConfig,
@@ -21,31 +319,106 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+ self.select_layer = config.select_layer
+
llm_config = config.llm_config
+ self.llm_arch_name = llm_config.architectures[0]
+ self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
+
+ vision_config = config.vision_config
+ if self.is_mono:
+ from .internvl_patch import InternVisionPatchModel
+ self.vision_model = InternVisionPatchModel(
+ vision_config,
+ dtype=dtype,
+ device=device,
+ )
+ else:
+ self.vision_model = InternVisionModel(vision_config,
+ dtype=dtype,
+ device=device)
+
self.language_model = build_model_from_hf_config(llm_config,
dtype=dtype,
device=device)
- self.llm_arch_name = llm_config.architectures[0]
+ vit_hidden_size = config.vision_config.hidden_size
+ llm_hidden_size = config.llm_config.hidden_size
+ self.downsample_ratio = config.downsample_ratio
+ self.mlp1 = nn.Sequential(
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+ dtype=dtype,
+ device=device),
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+ llm_hidden_size,
+ dtype=dtype,
+ device=device), nn.GELU(),
+ nn.Linear(llm_hidden_size,
+ llm_hidden_size,
+ dtype=dtype,
+ device=device))
# for Mono-InternVL
- self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
if self.is_mono:
assert dtype != torch.float16, (
                 'Currently Mono-InternVL does not support FP16 due to '
'numerical instability. Please use BF16 instead.')
+ self.input_processor = InternVLInputProcessor(self.config, dtype)
+
+ def pixel_shuffle(self, x, scale_factor=0.5):
+ n, w, h, c = x.size()
+ # N, W, H, C --> N, W, H * scale, C // scale
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+ x = x.permute(0, 2, 1, 3).contiguous()
+ # N, H * scale, W, C // scale -->
+ # N, H * scale, W * scale, C // (scale ** 2)
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
+ int(c / (scale_factor * scale_factor)))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+ def extract_feature(self, pixel_values):
+ """extract vision feature."""
+ assert self.select_layer == -1
+ vit_embeds = self.vision_model(pixel_values)
+ if self.is_mono:
+ if int(vit_embeds.shape[1]**0.5)**2 != vit_embeds.shape[1]:
+ vit_embeds = vit_embeds[:, 1:, :]
+ else:
+ vit_embeds = vit_embeds[:, 1:, :]
+
+ h = w = int(vit_embeds.shape[1]**0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = self.pixel_shuffle(vit_embeds,
+ scale_factor=self.downsample_ratio)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
+ vit_embeds.shape[-1])
+ vit_embeds = self.mlp1(vit_embeds)
+ return vit_embeds
+
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
vision_embedding_indexing: torch.Tensor = None,
text_embedding_indexing: torch.Tensor = None,
**kwargs,
):
+ if inputs_embeds is None and pixel_values is not None:
+ # extract feature
+ vit_embeds = self.extract_feature(pixel_values)
+ lang_embeds = self.language_model.get_input_embeddings()(input_ids)
+ lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
+
+ inputs_embeds = lang_embeds
+
if self.is_mono:
return self.language_model.forward(
input_ids=input_ids,
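`extract_feature` drops the CLS token, folds the ViT sequence back into a 2D grid, and `pixel_shuffle` trades spatial resolution for channel depth before `mlp1` projects into the LLM width. A quick shape walk-through; the 448-pixel image with 14-pixel patches is an assumed example configuration:

```python
import torch


def pixel_shuffle(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    """Same permute/reshape dance as InternVLChatModel.pixel_shuffle above."""
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h * scale_factor), int(w * scale_factor),
               int(c / (scale_factor ** 2)))
    return x.permute(0, 2, 1, 3).contiguous()


# e.g. a 448x448 image with 14x14 patches -> 32x32 grid of 1024-dim tokens
vit = torch.randn(1, 32, 32, 1024)
out = pixel_shuffle(vit, scale_factor=0.5)
print(out.shape)                        # torch.Size([1, 16, 16, 4096])
print(out.flatten(1, 2).shape[1])       # 256 tokens fed to mlp1
```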
@@ -80,11 +453,38 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
- # get inputs from context
vision_embeddings = context.input_embeddings
- vision_embedding_indexing = context.input_embedding_indexing
+ vision_embedding_indexing = None
+
+ # vision inputs
+ pixel_values = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ pixel_values = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ pixel_values = [
+ data for im_data in pixel_values for data in im_data
+ ]
+ if len(pixel_values) > 0:
+ image_token_id = pixel_values[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ pixel_values = torch.cat([data.data for data in pixel_values])
+ else:
+ pixel_values = None
+ image_mask = None
+ if self.is_mono and pixel_values is not None:
+ vision_embedding_indexing = torch.arange(input_ids.shape[1],
+ device=input_ids.device)
+ vision_embedding_indexing = vision_embedding_indexing[
+ image_mask[0]]
+
+ # get inputs from context
if vision_embeddings is not None and len(vision_embeddings) > 0:
+ vision_embedding_indexing = context.input_embedding_indexing
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds[:,
@@ -104,6 +504,8 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
vision_embedding_indexing=vision_embedding_indexing,
text_embedding_indexing=text_embedding_indexing,
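`prepare_inputs_for_generation` now flattens the per-request image tensors into a single batch and builds a boolean mask over `input_ids` at the image token id; `forward` then scatters the projected vision features into the language embeddings at exactly those positions. A minimal sketch of the mask-and-scatter step with made-up ids and dimensions:

```python
import torch

hidden = 8
image_token_id = 92546                      # illustrative placeholder id
input_ids = torch.tensor([[1, 5, image_token_id, image_token_id, 7]])

lang_embeds = torch.zeros(1, 5, hidden)
vit_embeds = torch.ones(1, 2, hidden)       # one feature per image token

image_mask = input_ids == image_token_id    # [1, 5] bool
lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
print(lang_embeds[0, :, 0])                 # tensor([0., 0., 1., 1., 0.])
```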
@@ -114,18 +516,96 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
+ def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+ adapter_id: int):
+ """load lora weights."""
+
+ if hasattr(self.language_model, 'load_lora_weights'):
+ return self.language_model.load_lora_weights(weights, adapter_id)
+ else:
+ from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+            return load_lora_weights(self, weights, adapter_id)
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
- prefix_length = len('language_model.')
+ lang_prefix = 'language_model.'
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if name.startswith(lang_prefix):
+ continue
+
+ if 'qkv' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ lang_prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
- if not key.startswith('language_model.'):
+ if not key.startswith(lang_prefix):
continue
- new_key = key[prefix_length:]
+ new_key = key[lang_prefix_length:]
new_weights[new_key] = val
self.language_model.load_weights(new_weights.items())
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class InternVLInputProcessor(BaseModelInputProcessor):
+ """internvl input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ vision_config = config.vision_config
+ self.image_size = vision_config.image_size
+ self.patch_size = vision_config.patch_size
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+ self.vision_token_num = self.num_patches // 4
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
diff --git a/lmdeploy/pytorch/models/internvl_patch.py b/lmdeploy/pytorch/models/internvl_patch.py
new file mode 100644
index 0000000000..d13ad2d39b
--- /dev/null
+++ b/lmdeploy/pytorch/models/internvl_patch.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+
+class InternVisionEmbeddings(nn.Module):
+ """mono vision."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), )
+
+ self.patch_embedding = nn.Conv2d(in_channels=3,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ dtype=dtype,
+ device=device)
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(
+ torch.empty(1,
+ self.num_positions,
+ self.embed_dim,
+ dtype=dtype,
+ device=device))
+
+ def _get_pos_embed(self, pos_embed, H, W):
+ target_dtype = pos_embed.dtype
+ pos_embed = pos_embed.float().reshape(
+ 1, self.image_size // self.patch_size,
+ self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
+ pos_embed = F.interpolate(pos_embed,
+ size=(H, W),
+ mode='bicubic',
+ align_corners=False)
+ pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
+ 1).to(target_dtype)
+ return pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+            pixel_values)  # shape = [batch, embed_dim, grid_h, grid_w]
+ batch_size, _, height, width = patch_embeds.shape
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1,
+ -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ position_embedding = torch.cat([
+ self.position_embedding[:, :1, :],
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height,
+ width)
+ ],
+ dim=1)
+ embeddings = embeddings + position_embedding.to(target_dtype)
+ return embeddings
+
+
+class InternVisionPatchModel(nn.Module):
+ """mono vision."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embeddings = InternVisionEmbeddings(config,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ if len(pixel_values.shape) != 4:
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
+
+ hidden_states = self.embeddings(pixel_values)[:, 1:]
+ return hidden_states
diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py
index f38c5ef02b..1a98c02f03 100644
--- a/lmdeploy/pytorch/models/llama.py
+++ b/lmdeploy/pytorch/models/llama.py
@@ -29,7 +29,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -40,6 +41,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
+ num_replicate_kv_heads=num_replicate_kv_heads,
)
# rotary embedding
@@ -450,22 +452,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
-
-
-class LlavaLlamaForCausalLM(LlamaForCausalLM):
- """llava llama for causallm."""
-
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- """load weights."""
-
- new_weights = dict()
- for key, val in weights:
- if key.startswith('model.vision_tower'):
- continue
- if key.startswith('model.mm_projector'):
- continue
- if key.startswith('model.image_newline'):
- continue
- new_weights[key] = val
-
- super().load_weights(new_weights.items())
diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py
index 56cb5ca675..751f7343ec 100644
--- a/lmdeploy/pytorch/models/llava.py
+++ b/lmdeploy/pytorch/models/llava.py
@@ -1,17 +1,443 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
+import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.llava.configuration_llava import LlavaConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj,
+ build_rowwise_linear)
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
-class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin):
+class LlavaMultiModalProjector(nn.Module):
+
+ def __init__(self,
+ config: LlavaConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device)
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(config.text_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class CLIPVisionEmbeddings(nn.Module):
+ """clip vision embedding."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(self.embed_dim, dtype=dtype, device=device))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(
+ self.num_positions,
+ self.embed_dim,
+ dtype=dtype,
+ device=device,
+ )
+ self.register_buffer('position_ids',
+ torch.arange(self.num_positions,
+ device=device).expand((1, -1)),
+ persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
+ width: int) -> torch.Tensor:
+ """This method allows to interpolate the pre-trained position
+ encodings, to be able to use the model on higher resolution images.
+
+ This method is also adapted to support torch.jit tracing.
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
+ num_positions = position_embedding.shape[1] - 1
+
+ # always interpolate when tracing
+ # to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing(
+ ) and num_patches == num_positions and height == width:
+ return self.position_embedding(self.position_ids)
+
+ from transformers.utils import torch_int
+
+ class_pos_embed = position_embedding[:, :1]
+ patch_pos_embed = position_embedding[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
+ sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode='bicubic',
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self,
+ pixel_values: torch.FloatTensor,
+ interpolate_pos_encoding=False) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ if not interpolate_pos_encoding and (height != self.image_size
+ or width != self.image_size):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f' ({self.image_size}*{self.image_size}).')
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(
+            dtype=target_dtype))  # shape = [batch, embed_dim, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(
+ embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embedding(
+ self.position_ids)
+ return embeddings
+
+
+class CLIPAttention(nn.Module):
+ """clip attention."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+
+ self.qkv_proj = build_qkv_proj(
+ self.embed_dim,
+ num_q_heads=self.num_heads,
+ num_kv_heads=self.num_heads,
+ head_size=self.head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.scale = self.head_dim**-0.5
+
+ # o_proj
+ self.out_proj = build_rowwise_linear(self.embed_dim,
+ self.embed_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """forward."""
+ # qkv proj
+ qkv_states = self.qkv_proj(hidden_states)
+ q, k, v = self.qkv_proj.split_qkv(qkv_states)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ if attention_mask is not None and causal_attention_mask is not None:
+ attn_mask = attention_mask + causal_attention_mask
+ elif causal_attention_mask is not None:
+ attn_mask = causal_attention_mask
+ else:
+ attn_mask = attention_mask
+
+ attn_output = F.scaled_dot_product_attention(q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.out_proj(attn_output)
+ return attn_output
+
+
+class CLIPMLP(nn.Module):
+ """clip mlp."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ from transformers.activations import ACT2FN
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+ self.fc2 = build_rowwise_linear(
+ config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class CLIPEncoderLayer(nn.Module):
+ """clip encoder layer."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = CLIPAttention(config, dtype=dtype, device=device)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.mlp = CLIPMLP(config, dtype=dtype, device=device)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ ):
+ """forward."""
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class CLIPEncoder(nn.Module):
+ """clip encoder."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ CLIPEncoderLayer(config, dtype=dtype, device=device)
+ for _ in range(config.num_hidden_layers)
+ ])
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ vision_feature_layer: int = -1,
+ ):
+ """forward."""
+ hidden_states = inputs_embeds
+ num_vision_layers = len(self.layers) + vision_feature_layer + 1
+ for _, encoder_layer in enumerate(self.layers[:num_vision_layers]):
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+
+ hidden_states = layer_outputs
+
+ return hidden_states
+
+
+class CLIPVisionTransformer(nn.Module):
+ """clip vision transformer."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = CLIPVisionEmbeddings(config,
+ dtype=dtype,
+ device=device)
+ self.pre_layrnorm = nn.LayerNorm(embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.encoder = CLIPEncoder(config, dtype=dtype, device=device)
+ self.post_layernorm = nn.LayerNorm(embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ interpolate_pos_encoding: bool = False,
+ vision_feature_layer: int = -1,
+ ) -> BaseModelOutputWithPooling:
+ """forward."""
+ hidden_states = self.embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ vision_feature_layer=vision_feature_layer)
+
+ last_hidden_state = encoder_outputs
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=None,
+ attentions=None,
+ )
+
+
+class CLIPVisionModel(nn.Module):
+ """clip vision model."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.vision_model = CLIPVisionTransformer(config,
+ dtype=dtype,
+ device=device)
+
+ def forward(self,
+ pixel_values: torch.FloatTensor,
+ interpolate_pos_encoding: bool = False,
+ vision_feature_layer: int = -1,
+ **kwargs):
+ """forward."""
+ return self.vision_model(
+ pixel_values,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ vision_feature_layer=vision_feature_layer)
+
+
+def build_vision_model(vision_config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ """build vision model."""
+ model_type = vision_config.model_type
+
+ if model_type == 'clip_vision_model':
+ return CLIPVisionModel(vision_config, dtype, device)
+ else:
+ raise NotImplementedError(f'<{model_type}> is not implemented.')
+
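
The `vision_feature_layer` plumbing above lets `CLIPEncoder.forward` stop early: with a negative index, only the first `len(self.layers) + vision_feature_layer + 1` layers run, and the output of the last executed layer is returned. A minimal sketch of that arithmetic (standalone, not part of the patch):

```python
def num_layers_to_run(num_hidden_layers: int, vision_feature_layer: int = -1) -> int:
    # mirrors: num_vision_layers = len(self.layers) + vision_feature_layer + 1
    return num_hidden_layers + vision_feature_layer + 1

assert num_layers_to_run(24, -1) == 24  # run every layer
assert num_layers_to_run(24, -2) == 23  # stop one layer early (second-to-last features)
```
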
+
+class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin,
+ DeployModelMixin):
def __init__(self,
config: PretrainedConfig,
@@ -22,19 +448,67 @@ def __init__(self,
self.config = config
self.ctx_mgr = ctx_mgr
text_config = config.text_config
+
+ self.vision_tower = build_vision_model(config.vision_config,
+ dtype=dtype,
+ device=device)
+
self.language_model = build_model_from_hf_config(text_config,
dtype=dtype,
device=device)
+ self.multi_modal_projector = LlavaMultiModalProjector(config,
+ dtype=dtype,
+ device=device)
+
+ self.input_processor = LLavaInputProcessor(config, dtype)
+
+ def get_image_features(self,
+ pixel_values,
+ vision_feature_layer: int = -1,
+ vision_feature_select_strategy: str = 'default'):
+ """get image features."""
+ selected_image_feature = self.vision_tower(
+ pixel_values, vision_feature_layer=vision_feature_layer)[0]
+ if vision_feature_select_strategy == 'default':
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == 'full':
+ selected_image_feature = selected_image_feature
+ else:
+ raise ValueError(
+ f'Unexpected select feature strategy: {vision_feature_select_strategy}' # noqa: E501
+ )
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = image_features.flatten(0, 1)[None]
+
+ return image_features
+
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
**kwargs,
):
+ if inputs_embeds is None:
+ image_features = None
+ if pixel_values is not None:
+ vision_feature_layer = self.config.vision_feature_layer
+ select_strategy = self.config.vision_feature_select_strategy
+ image_features = self.get_image_features(
+ pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=select_strategy)
+ inputs_embeds = self.language_model.get_input_embeddings()(
+ input_ids)
+ if pixel_values is not None:
+ inputs_embeds.masked_scatter_(image_mask[..., None],
+ image_features)
+
return self.language_model.forward(input_ids=input_ids,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
@@ -59,6 +533,27 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
+
+ # vision inputs
+ pixel_values = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ pixel_values = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ pixel_values = [
+ data for im_data in pixel_values for data in im_data
+ ]
+ if len(pixel_values) > 0:
+ image_token_id = pixel_values[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ pixel_values = torch.cat([data.data for data in pixel_values])
+ else:
+ pixel_values = None
+ image_mask = None
+
# get inputs from context
vision_embeddings = context.input_embeddings
vision_embedding_indexing = context.input_embedding_indexing
@@ -75,18 +570,404 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
- prefix_length = len('language_model.')
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ('.qkv_proj', '.q_proj', 'q'),
+ ('.qkv_proj', '.k_proj', 'k'),
+ ('.qkv_proj', '.v_proj', 'v'),
+ ]
+
+ # vis model
+ lang_prefix = 'language_model.'
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if name.startswith(lang_prefix):
+ continue
+
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, shard_id=shard_id)
+ break
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ # language model
+ prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
- if not key.startswith('language_model.'):
+ if not key.startswith(lang_prefix):
continue
new_key = key[prefix_length:]
new_weights[new_key] = val
self.language_model.load_weights(new_weights.items())
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class LLavaInputProcessor(BaseModelInputProcessor):
+ """llava input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    """Get the patch-grid shape (rows, cols) of the best-fit anyres resolution."""
+
+ from transformers.image_processing_utils import select_best_resolution
+
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError('grid_pinpoints should be a list of tuples or lists')
+
+ if not isinstance(image_size, (list, tuple)):
+ image_size = image_size.tolist()
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def unpad_image(tensor, original_size):
+ """Unpads a PyTorch tensor of a padded and resized image."""
+ if not isinstance(original_size, (list, tuple)):
+ original_size = original_size.tolist()
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(round(original_height * scale_factor, 7))
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(round(original_width * scale_factor, 7))
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
+
+ return unpadded_tensor
+
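
`unpad_image` strips the letterbox padding that the anyres preprocessing added along the shorter side. A quick worked example with toy numbers, assuming the function above is in scope (illustrative only):

```python
import torch

# A 24x24 grid of patch features (dim=8) that came from an image originally
# 1000 (h) x 750 (w): the width was padded to make the grid square.
feat = torch.zeros(8, 24, 24)
out = unpad_image(feat, (1000, 750))
# scale_factor = 24 / 1000, new_width = round(750 * 24 / 1000) = 18,
# padding = (24 - 18) // 2 = 3  ->  columns 3:21 survive
assert out.shape == (8, 24, 18)
```
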
+
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+ """Calculate the number of patches after the preprocessing for images of
+ any resolution."""
+ from transformers.image_processing_utils import select_best_resolution
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError('grid_pinpoints should be a list of tuples or lists')
+
+ if not isinstance(image_size, (list, tuple)):
+ image_size = image_size.tolist()
+
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
+ height, width = best_resolution
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ # add the base patch
+ num_patches += 1
+ return num_patches
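
With a single anyres pinpoint, `select_best_resolution` can only pick that pinpoint, which makes the patch count easy to verify by hand. A sketch assuming the helper above is in scope (illustrative values; real configs list several pinpoints):

```python
# patch_size here is the vision tower's input size, matching how the model
# calls this helper with config.vision_config.image_size.
n = image_size_to_num_patches(image_size=(672, 672),
                              grid_pinpoints=[(672, 672)],
                              patch_size=336)
assert n == (672 // 336) * (672 // 336) + 1  # 2 x 2 grid + 1 base patch == 5
```
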
+
+
+class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration):
+    """llava-next model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__(config=config,
+ ctx_mgr=ctx_mgr,
+ dtype=dtype,
+ device=device)
+ self.image_newline = nn.Parameter(
+ torch.empty(config.text_config.hidden_size,
+ dtype=dtype,
+ device=device))
+ self.input_processor = LLavaNextInputProcessor(config, dtype)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: int,
+ vision_feature_select_strategy: str,
+ ):
+ # ! infer image_num_patches from image_sizes
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ ) for imsize in image_sizes
+ ]
+ if pixel_values.dim() == 5:
+ # stacked if input is
+ # (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [
+ pix_val[:num_patch]
+ for pix_val, num_patch in zip(pixel_values, image_num_patches)
+ ]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of
+ # (num_patches, num_channels, height, width)
+ raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
+ 'expect to be of 4 or 5 dimensions')
+
+ selected_image_feature = self.vision_tower(
+ pixel_values, vision_feature_layer=vision_feature_layer)[0]
+ if vision_feature_select_strategy == 'default':
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == 'full':
+ selected_image_feature = selected_image_feature
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = torch.split(image_features, image_num_patches, dim=0)
+ return image_features
+
+ def pack_image_features(self,
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy,
+ image_newline=None):
+
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = (self.config.vision_config.image_size //
+ self.config.vision_config.patch_size)
+
+ if vision_feature_select_strategy == 'default':
+ expected_num_patches = height * width
+ elif vision_feature_select_strategy == 'full':
+ expected_num_patches = height * width + 1
+ if expected_num_patches != base_image_feature.shape[0]:
+ raise ValueError('The number of patches is '
+ 'not consistent with the image size.')
+
+ (num_patch_height,
+ num_patch_width) = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(num_patch_height,
+ num_patch_width, height,
+ width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1,
+ 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature,
+ image_sizes[image_idx])
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None].expand(
+ *image_feature.shape[:-1], 1).to(
+ image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature),
+ dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (image_feature, image_newline[None].to(image_feature)),
+ dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ image_features = torch.cat(new_image_features, dim=0)
+ return image_features
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_sizes: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
+ inputs_embeds: torch.Tensor = None,
+ **kwargs,
+ ):
+ if inputs_embeds is None:
+ image_features = None
+ if pixel_values is not None:
+ vision_feature_layer = self.config.vision_feature_layer
+ select_strategy = self.config.vision_feature_select_strategy
+ image_sizes = image_sizes.tolist()
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=select_strategy)
+ image_features = self.pack_image_features(
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy=select_strategy,
+ image_newline=self.image_newline,
+ )
+ image_features = image_features[None]
+ inputs_embeds = self.language_model.get_input_embeddings()(
+ input_ids)
+ if pixel_values is not None:
+ inputs_embeds.masked_scatter_(image_mask[..., None],
+ image_features)
+
+ return self.language_model.forward(input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ attn_metadata=attn_metadata)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: torch.Tensor = None,
+ context: StepContext = None,
+ ):
+ """prepare input."""
+ input_ids = context.input_ids
+ position_ids = context.position_ids
+ attn_metadata = context.attn_metadata
+
+ # vision inputs
+ pixel_values = None
+ image_sizes = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ img_mms = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ img_mms = [data for im_data in img_mms for data in im_data]
+ if len(img_mms) > 0:
+ image_token_id = img_mms[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ pixel_values = torch.cat(
+ [data.data.flatten(0, 1) for data in img_mms])
+ image_sizes = torch.cat(
+ [data.meta['image_sizes'] for data in img_mms])
+ else:
+ pixel_values = None
+ image_sizes = None
+
+ # get inputs from context
+ vision_embeddings = context.input_embeddings
+ vision_embedding_indexing = context.input_embedding_indexing
+
+ if vision_embeddings is not None and len(vision_embeddings) > 0:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds[:,
+ vision_embedding_indexing, :] = vision_embeddings.to(
+ inputs_embeds)
+
+ return dict(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ image_mask=image_mask,
+ inputs_embeds=inputs_embeds,
+ )
+
+
+class LLavaNextInputProcessor(BaseModelInputProcessor):
+    """llava next input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ image_sizes = input_mm['image_sizes']
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(
+ image_sizes=image_sizes,
+ image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
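
Both `forward` methods above splice the projected image features into the token embeddings in place: a boolean mask marks the image-token positions and `masked_scatter_` fills them, row by row, from the flattened feature tensor. A self-contained sketch of that merge with toy sizes (the token id 32000 is a hypothetical placeholder, not tied to any checkpoint):

```python
import torch

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id
hidden = 16

input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5, 6]])
inputs_embeds = torch.zeros(1, 5, hidden)
image_features = torch.ones(1, 2, hidden)  # one feature vector per image token

image_mask = input_ids == IMAGE_TOKEN_ID  # (1, 5) boolean mask
inputs_embeds.masked_scatter_(image_mask[..., None], image_features)

assert inputs_embeds[0, 1].sum() == hidden  # image positions were overwritten
assert inputs_embeds[0, 3].sum() == 0       # text positions keep their embeddings
```
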
diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py
index bc9d086dc9..28da3bad55 100644
--- a/lmdeploy/pytorch/models/medusa.py
+++ b/lmdeploy/pytorch/models/medusa.py
@@ -10,6 +10,7 @@
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
vicuna_7b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ),
(0, 2), (0, 0, 1), (0, 3), (3, ), (0, 1, 0), (2, 0), (4, ),
@@ -138,7 +139,7 @@ def forward(self, x):
return x + self.act(self.linear(x))
-class MedusaModel(nn.Module, CudaGraphMixin):
+class MedusaModel(nn.Module, CudaGraphMixin, DeployModelMixin):
"""The medusa model architecture."""
packed_modules_mapping = {
diff --git a/lmdeploy/pytorch/models/minicpmv26.py b/lmdeploy/pytorch/models/minicpmv26.py
index 725e97d9d7..9e47c56437 100644
--- a/lmdeploy/pytorch/models/minicpmv26.py
+++ b/lmdeploy/pytorch/models/minicpmv26.py
@@ -29,7 +29,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -40,7 +41,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
@@ -226,7 +227,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -246,7 +246,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py
index 04af4c8526..962cdb3d2b 100644
--- a/lmdeploy/pytorch/models/mistral.py
+++ b/lmdeploy/pytorch/models/mistral.py
@@ -223,7 +223,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -240,7 +239,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -420,22 +418,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
-
-
-class LlavaMistralForCausalLM(MistralForCausalLM):
- """llava forcausallm."""
-
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- """load weights."""
-
- new_weights = dict()
- for key, val in weights:
- if key.startswith('model.vision_tower'):
- continue
- if key.startswith('model.mm_projector'):
- continue
- if key.startswith('model.image_newline'):
- continue
- new_weights[key] = val
-
- super().load_weights(new_weights.items())
diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py
index d98efee712..be414d7bff 100644
--- a/lmdeploy/pytorch/models/mixtral.py
+++ b/lmdeploy/pytorch/models/mixtral.py
@@ -8,7 +8,7 @@
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -22,7 +22,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = None
+ quantization_config = getattr(config, 'quantization_config', None)
num_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
@@ -112,6 +112,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
@@ -124,11 +125,12 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=False,
+ quant_config=None,
)
self.softmax_topk = SoftmaxTopK(self.top_k)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -137,6 +139,7 @@ def __init__(self,
dtype=dtype,
device=device,
all_reduce=True,
+ quant_config=quantization_config,
)
def forward(self, hidden_states: torch.Tensor):
@@ -166,7 +169,7 @@ def __init__(self,
device: torch.device = None):
super().__init__()
self.layer_idx = layer_idx
- quantization_config = None
+ quantization_config = getattr(config, 'quantization_config', None)
# build attention layer
self.self_attn = MixtralAttention(config, dtype=dtype, device=device)
@@ -182,12 +185,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -376,12 +377,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.num_local_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w1.weight', exp_id, 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w3.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.w2.weight', exp_id, 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id,
+ 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id,
+ 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id,
+ 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
diff --git a/lmdeploy/pytorch/models/mllama.py b/lmdeploy/pytorch/models/mllama.py
index 2596fe5299..15b3e9732b 100644
--- a/lmdeploy/pytorch/models/mllama.py
+++ b/lmdeploy/pytorch/models/mllama.py
@@ -3,23 +3,61 @@
import torch
from torch import nn
+from torch.nn import functional as F
from transformers.models.llama import LlamaConfig
-from transformers.models.mllama.modeling_mllama import MllamaTextConfig
+from transformers.models.mllama.modeling_mllama import (MllamaTextConfig,
+ MllamaVisionConfig)
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
- SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, RMSNorm,
+ RopeType, SiluAndMul, build_rotary_embedding)
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.nn.rotary_embedding import Llama3Parameters
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
-from .utils.cudagraph import CudaGraphMixin
+from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2
+from .utils.model import DeployModelMixin
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = '<|image|>'
+def _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask: torch.Tensor,
+ num_patches: int,
+ target_length: int,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ # Expand aspect ratio mask to target_length
+ batch_size, max_num_tiles = aspect_ratio_mask.shape
+ attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
+ 1).to(dtype)
+ attention_mask = attention_mask.repeat(1, 1, target_length, 1)
+
+ # Mask padding patches
+ pad_patches = target_length - num_patches
+ attention_mask[:, :, -pad_patches:] = 0
+
+ # Invert the mask (0 -> 1, 1 -> 0)
+ attention_mask = 1 - attention_mask
+
+ # Reshape to 2D and create 4D attention mask
+ # (batch_size, 1, max_num_tiles * target_length,
+ # max_num_tiles * target_length)
+ attention_mask = attention_mask.reshape(batch_size,
+ max_num_tiles * target_length, 1)
+ attention_mask = attention_mask * attention_mask.transpose(
+ -1, -2) * torch.finfo(dtype).min
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
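
A quick shape check of the helper above with toy sizes: one sample, two tiles (the second is padding), three real patches padded to a target length of four. The result is the usual 4-D additive mask where blocked positions carry `torch.finfo(dtype).min` (sketch, assuming the function above is in scope):

```python
import torch

mask = torch.tensor([[1, 0]])  # (batch, max_num_tiles); tile 1 is padding
attn_mask = _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask=mask,
    num_patches=3,
    target_length=4,
    dtype=torch.float32,
)
# (batch, 1, max_num_tiles * target_length, max_num_tiles * target_length)
assert attn_mask.shape == (1, 1, 8, 8)
assert attn_mask[0, 0, 0, 0] == 0   # real patch attends to real patch
assert attn_mask[0, 0, 4, 4] < 0    # padded tile is masked out
```
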
+
class LlamaAttention(nn.Module):
"""Rewrite module of LlamaAttention."""
@@ -157,6 +195,7 @@ def __init__(self,
self.head_dim,
num_kv_heads=self.num_key_value_heads,
v_head_size=self.head_dim,
+ causal=False,
)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -579,7 +618,542 @@ def get_logits(self, hidden_states: torch.Tensor):
return self.lm_head(hidden_states)
-class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin):
+class MllamaPrecomputedPositionEmbedding(nn.Module):
+ """vis position embedding."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+ self.config = config
+ self.num_patches = (config.image_size // config.patch_size)**2 + 1
+ self.hidden_size = config.hidden_size
+
+ self.gate = nn.Parameter(torch.empty(1, dtype=dtype, device=device))
+
+ # position embedding
+ self.embedding = nn.Parameter(
+ torch.empty(self.num_patches,
+ self.hidden_size,
+ dtype=dtype,
+ device=device))
+
+ # tile position embedding
+ self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
+ self.max_num_tiles *
+ self.num_patches * self.hidden_size,
+ dtype=dtype,
+ device=device)
+
+ self._weight_inited = False
+
+ def _init_weight(self):
+ """init weight."""
+ if self._weight_inited:
+ return
+
+ gate_tanh = self.gate.tanh()
+ gated_position_embedding = (1 - gate_tanh) * self.embedding
+ self.gate_tanh = gate_tanh
+ self.gated_position_embedding = gated_position_embedding.view(
+ 1, 1, self.num_patches, self.hidden_size)
+
+ self._weight_inited = True
+
+ def forward(self, hidden_state: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ self._init_weight()
+
+ # position embeddings
+ hidden_state = hidden_state + self.gated_position_embedding
+
+ # precomputed tile position embeddings
+ tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+ batch_size = hidden_state.shape[0]
+ tile_position_embedding = tile_position_embedding.reshape(
+ batch_size, self.max_num_tiles, self.num_patches, self.hidden_size)
+ gated_tile_position_embedding = (self.gate_tanh *
+ tile_position_embedding)
+ hidden_state = hidden_state + gated_tile_position_embedding
+
+ return hidden_state
+
+
+class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
+    """precomputed aspect ratio embedding."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ is_gated: bool = True,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.hidden_size = config.hidden_size
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+ self.is_gated = is_gated
+
+ self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
+ self.max_num_tiles * self.hidden_size,
+ dtype=dtype,
+ device=device)
+ if is_gated:
+ self.gate = nn.Parameter(torch.empty(1, dtype=dtype,
+ device=device))
+
+ self._weight_inited = False
+
+ def _init_weight(self):
+ """init weight."""
+ if self._weight_inited:
+ return
+
+ gate_tanh = self.gate.tanh()
+ self.gate_tanh = gate_tanh
+
+ self._weight_inited = True
+
+ def forward(self, hidden_state: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+ self._init_weight()
+ embeddings = self.embedding(aspect_ratio_ids)
+ embeddings = embeddings.reshape(-1, self.max_num_tiles, 1,
+ self.hidden_size)
+
+ if self.is_gated:
+ embeddings = embeddings * self.gate_tanh
+
+ hidden_state = hidden_state + embeddings
+ return hidden_state
+
+
+class MllamaVisionAttention(nn.Module):
+ """mllama vision attention."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.attention_heads
+ self.head_dim = config.hidden_size // config.attention_heads
+
+ # packed qkv
+ self.qkv_proj = build_qkv_proj(
+ self.embed_dim,
+ num_q_heads=self.num_heads,
+ num_kv_heads=self.num_heads,
+ head_size=self.head_dim,
+ bias=False,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # o_proj
+ self.o_proj = build_rowwise_linear(self.num_heads * self.head_dim,
+ self.embed_dim,
+ bias=False,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_state.size(0)
+ qkv_states = self.qkv_proj(hidden_state)
+ qkv_states = qkv_states.flatten(0, -2)
+ query, key, value = self.qkv_proj.split_qkv(qkv_states)
+
+ query = query.unflatten(0, (batch_size, -1))
+ key = key.unflatten(0, (batch_size, -1))
+ value = value.unflatten(0, (batch_size, -1))
+ q_seq_len = query.shape[1]
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(query,
+ key,
+ value,
+ attn_mask=attention_mask)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+ output = self.o_proj(attn_output)
+
+ return output
+
+
+class MllamaVisionMLP(nn.Module):
+ """mllama vision mlp."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MllamaVisionEncoderLayer(nn.Module):
+ """vision encoder layer."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ is_gated: bool,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.is_gated = is_gated
+ self.self_attn = MllamaVisionAttention(config,
+ dtype=dtype,
+ device=device)
+ self.mlp = MllamaVisionMLP(config, dtype=dtype, device=device)
+
+ self.input_layernorm = LayerNorm(self.hidden_size,
+ eps=config.norm_eps,
+ dtype=dtype,
+ device=device)
+ self.post_attention_layernorm = LayerNorm(self.hidden_size,
+ eps=config.norm_eps,
+ dtype=dtype,
+ device=device)
+
+ if is_gated:
+ self.gate_attn = nn.Parameter(
+ torch.empty(1, dtype=dtype, device=device))
+ self.gate_ffn = nn.Parameter(
+ torch.empty(1, dtype=dtype, device=device))
+
+ self._weight_inited = not is_gated
+
+ def _init_weight(self):
+ """init weight."""
+ if self._weight_inited:
+ return
+
+ self.gate_attn_tanh = self.gate_attn.tanh()
+ self.gate_ffn_tanh = self.gate_ffn.tanh()
+
+ self._weight_inited = True
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """forward."""
+ self._init_weight()
+
+ # Self Attention
+ residual = hidden_state
+ hidden_state = self.input_layernorm(hidden_state)
+ hidden_state = self.self_attn(hidden_state,
+ attention_mask=attention_mask)
+ if self.is_gated:
+ hidden_state = self.gate_attn_tanh * hidden_state
+ hidden_state = residual + hidden_state
+
+ # Feed forward
+ residual = hidden_state
+ hidden_state = self.post_attention_layernorm(hidden_state)
+ hidden_state = self.mlp(hidden_state)
+ if self.is_gated:
+ hidden_state = self.gate_ffn_tanh * hidden_state
+ hidden_state = residual + hidden_state
+
+ outputs = hidden_state
+
+ return outputs
+
+
+class MllamaVisionEncoder(nn.Module):
+ """vision encoder."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ num_layers=32,
+ is_gated=False,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ MllamaVisionEncoderLayer(config,
+ is_gated,
+ dtype=dtype,
+ device=device) for _ in range(num_layers)
+ ])
+        self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """forward."""
+ encoder_states = ()
+ for encoder_layer in self.layers:
+ encoder_states = encoder_states + (hidden_states, )
+ hidden_states = encoder_layer(
+ hidden_state=hidden_states,
+ attention_mask=attention_mask,
+ )
+ encoder_states = encoder_states + (hidden_states, )
+
+ return hidden_states, encoder_states
+
+
+class MllamaVisionModel(nn.Module):
+ """vision model."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+
+ self.config = config
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.hidden_size = config.hidden_size
+ self.intermediate_layers_indices = config.intermediate_layers_indices
+ self.dtype = dtype
+
+ self.num_patches = (self.image_size // self.patch_size)**2 + 1
+ self.scale = config.hidden_size**-0.5
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.hidden_size,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding='valid',
+ bias=False,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(self.hidden_size, dtype=dtype, device=device))
+ self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
+ config,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501
+ config,
+ is_gated=True,
+ dtype=dtype,
+ device=device,
+ )
+ self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501
+ config,
+ is_gated=True,
+ dtype=dtype,
+ device=device,
+ )
+
+ # layer norms
+ self.layernorm_pre = nn.LayerNorm(
+ self.hidden_size,
+ dtype=dtype,
+ device=device,
+ )
+ self.layernorm_post = nn.LayerNorm(
+ self.hidden_size,
+ dtype=dtype,
+ device=device,
+ )
+
+ # encoders
+ self.transformer = MllamaVisionEncoder(
+ config,
+ config.num_hidden_layers,
+ is_gated=False,
+ dtype=dtype,
+ device=device,
+ )
+ self.global_transformer = MllamaVisionEncoder(
+ config,
+ config.num_global_layers,
+ is_gated=True,
+ dtype=dtype,
+ device=device,
+ )
+
+ def apply_class_embedding(self,
+ hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1,
+ hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor,
+ aspect_ratio_mask: torch.Tensor,
+ ):
+ """forward."""
+ (batch_size, num_concurrent_media, num_tiles, num_channels, height,
+ width) = pixel_values.shape
+
+ pixel_values = pixel_values.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_channels,
+ height, width)
+ aspect_ratio_ids = aspect_ratio_ids.reshape(
+ batch_size * num_concurrent_media, -1)
+
+ # Patch embedding
+ patch_embeds = self.patch_embedding(pixel_values.to(self.dtype))
+ hidden_state = patch_embeds.flatten(2).transpose(1, 2)
+
+ # Tile embeddings
+ _, num_patches, dim = hidden_state.shape
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles, -1, dim)
+ hidden_state = self.pre_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids)
+
+ # Add cls token
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_patches, dim)
+ hidden_state = self.apply_class_embedding(hidden_state)
+ num_patches += 1
+
+ # Position embeddings
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles, num_patches, dim)
+ hidden_state = self.gated_positional_embedding(hidden_state,
+ aspect_ratio_ids)
+
+ hidden_state = self.layernorm_pre(hidden_state)
+
+ # Compute the number of tokens to pad
+ num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
+ # Compute padding tuple for pad function
+ padding = (
+ 0, 0, 0, num_padding_patches
+ ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+ # Pad the tensor
+ hidden_state = F.pad(hidden_state, padding, mode='constant', value=0)
+ slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+ # Prepare attention mask
+ attention_mask = aspect_ratio_mask.reshape(
+ batch_size * num_concurrent_media, -1)
+ attention_mask = _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask=attention_mask,
+ num_patches=self.num_patches,
+ target_length=hidden_state.shape[2],
+ dtype=self.dtype,
+ )
+
+ # Apply encoder
+ hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
+ dim)
+ output = self.transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ )
+ hidden_state = output[0]
+
+ hidden_state = self.layernorm_post(hidden_state)
+
+ # Apply global encoder
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim)
+ hidden_state = self.post_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids)
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles * (num_patches + num_padding_patches), dim)
+ global_output = self.global_transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ )
+ hidden_state = global_output[0]
+
+        # Remove padding from hidden state
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim)
+ hidden_state = hidden_state[:, :, :slice_index]
+ hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
+ num_tiles, num_patches, dim)
+
+ # Collect intermediate layer outputs from encoder output
+ all_intermediate_hidden_states = output[1]
+ all_intermediate_hidden_states = [
+ all_intermediate_hidden_states[i]
+ for i in self.intermediate_layers_indices
+ ]
+ intermediate_hidden_states = torch.stack(
+ all_intermediate_hidden_states, dim=-1)
+
+ # Remove padding from intermediate hidden states
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size * num_concurrent_media, num_tiles,
+ num_patches + num_padding_patches, -1)
+ intermediate_hidden_states = intermediate_hidden_states[:, :, :
+ slice_index]
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size, num_concurrent_media, num_tiles, num_patches, -1)
+
+ # Concatenate final hidden state and intermediate hidden states
+ hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
+ dim=-1)
+
+ return hidden_state
+
+
+class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin,
+ DeployModelMixin):
"""rewrote model of MllamaForConditionalGeneration."""
packed_modules_mapping = {
@@ -602,16 +1176,32 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+
+ self.vision_model = MllamaVisionModel(
+ config.vision_config,
+ dtype=dtype,
+ device=device,
+ )
# build MllamaForCausalLM
self.language_model = MllamaForCausalLM(config.text_config,
dtype=dtype,
device=device)
+
+ self.multi_modal_projector = build_rowwise_linear(
+ config.vision_config.vision_output_dim,
+ config.text_config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ )
self.dtype = dtype
- def flat_encoder_result(self, cross_attention_states: torch.Tensor,
- attn_metadata: Any, input_ids: torch.LongTensor):
+ # preprocessor
+ self.input_processor = MLlamaInputProcessor(self.config, dtype)
+
+ def flat_encoder_result(self, attn_metadata: Any,
+ input_ids: torch.LongTensor):
        # since every state shares the same shape
- cross_attention_states = torch.cat(cross_attention_states, 0)
full_text_row_masked_out_mask = torch.ones(
(attn_metadata.q_seqlens.sum(), 1), dtype=torch.bool)
start_pos = 0
@@ -621,39 +1211,51 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
full_text_row_masked_out_mask[start_pos:img_id] = False
start_pos += q_seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
- cross_attention_states.device)
+ input_ids.device)
- return cross_attention_states, full_text_row_masked_out_mask
+ return full_text_row_masked_out_mask
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
- cross_attention_states: Optional[torch.Tensor] = None,
+ pixel_values: torch.Tensor = None,
+ aspect_ratio_ids: torch.Tensor = None,
+ aspect_ratio_mask: torch.Tensor = None,
attn_metadata: Any = None,
inputs_embeds: torch.Tensor = None,
cross_attn_metadata: Any = None,
**kwargs,
):
"""model forward, return logits."""
+
if cross_attn_metadata is None:
full_text_row_masked_out_mask = None
# FIXME basically, we want to inference
# text requests and image requests separately
- elif cross_attention_states is None and (
- cross_attn_metadata.kv_seqlens is None
- or int(cross_attn_metadata.kv_seqlens.sum()) == 0):
+ elif pixel_values is None and (cross_attn_metadata.kv_seqlens is None):
full_text_row_masked_out_mask = None
elif cross_attn_metadata.is_decoding:
- cross_attention_states = None
- full_text_row_masked_out_mask = torch.ones(
- (attn_metadata.q_seqlens.sum(), 1),
- dtype=torch.bool,
- device=input_ids.device)
+ full_text_row_masked_out_mask = input_ids.new_ones(
+ input_ids.size(-1), 1)
else:
- cross_attention_states, full_text_row_masked_out_mask = \
- self.flat_encoder_result(cross_attention_states, cross_attn_metadata, input_ids) # noqa
+ full_text_row_masked_out_mask = self.flat_encoder_result(
+ cross_attn_metadata, input_ids) # noqa
+
+ cross_attention_states = None
+ if pixel_values is not None:
+ cross_attention_states = self.vision_model(
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ )
+ cross_attention_states = self.multi_modal_projector(
+ cross_attention_states)
+ _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape)
+ cross_attention_states = cross_attention_states.view(
+ bsz, -1, image_token_dim)
+
hidden_states = self.language_model(
input_ids=input_ids,
position_ids=position_ids,
@@ -670,15 +1272,6 @@ def get_logits(self, hidden_states: torch.Tensor):
"""compute logits of the model output."""
return self.language_model.get_logits(hidden_states)
- def support_cuda_graph(
- self,
- input_ids: torch.Tensor,
- **kwargs,
- ):
- """support cudagraph."""
-
- return False
-
def get_input_embeddings(self):
"""get input embeddings."""
return self.language_model.model.get_input_embeddings()
@@ -694,14 +1287,35 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
- cross_attention_states = context.cross_attention_states
- if cross_attention_states is not None:
- cross_attention_states = [
- t.to(input_ids.device) for t in cross_attention_states
- if t is not None
- ]
cross_attn_metadata = context.cross_attn_metadata
+        # cross_attn_metadata is None when the inputs carry no image
+ if cross_attn_metadata is not None and int(
+ cross_attn_metadata.kv_seqlens.sum()) == 0:
+ cross_attn_metadata.kv_seqlens = None
+
+ device = input_ids.device
+
+ # process image input
+ pixel_values = None
+ aspect_ratio_ids = None
+ aspect_ratio_mask = None
+ if context.input_multimodals is not None:
+ pixel_values = []
+ aspect_ratio_ids = []
+ aspect_ratio_mask = []
+ batched_image_data = [
+ input_mm['image'] for input_mm in context.input_multimodals
+ ]
+ for image_data in batched_image_data:
+ for data in image_data:
+ pixel_values.append(data.data)
+ aspect_ratio_ids.append(data.meta['aspect_ratio_ids'])
+ aspect_ratio_mask.append(data.meta['aspect_ratio_mask'])
+ pixel_values = torch.cat(pixel_values, dim=0).to(device)
+ aspect_ratio_ids = torch.cat(aspect_ratio_ids, dim=0).to(device)
+ aspect_ratio_mask = torch.cat(aspect_ratio_mask, dim=0).to(device)
+
# process vision embeddings
vision_embeddings = context.input_embeddings
vision_embedding_indexing = context.input_embedding_indexing
@@ -719,7 +1333,9 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
- cross_attention_states=cross_attention_states,
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
cross_attn_metadata=cross_attn_metadata,
)
@@ -742,8 +1358,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if ('rotary_emb.cos_cached' in name
or 'rotary_emb.sin_cached' in name):
continue
- if 'vision_model' in name or 'multi_modal_projector' in name:
- continue
if self.config.text_config.tie_word_embeddings and 'lm_head.weight' in name: # noqa
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
@@ -756,3 +1370,161 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
+
+ def support_cuda_graph(
+ self,
+ input_ids: torch.Tensor,
+ attn_metadata: Any,
+ cross_attn_metadata: Any,
+ **kwargs,
+ ):
+ """support cudagraph."""
+
+ if not attn_metadata.is_decoding:
+ return False
+
+ if cross_attn_metadata is None:
+ return False
+
+ if cross_attn_metadata.kv_seqlens is None:
+ return False
+
+ return True
+
+ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """make cudagraph buffers from forward inputs."""
+ input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta,
+ **kwargs)
+
+ device = graph_meta.device
+ max_batches = graph_meta.max_batchs
+ input_buffers['cross_kv_seqlens'] = torch.zeros(max_batches,
+ dtype=torch.int64,
+ device=device)
+
+ return input_buffers
+
+ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """fill cudagraph buffers from forward inputs."""
+ input_buffers = graph_meta.input_buffers
+
+ new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta,
+ **kwargs)
+
+ attn_metadata = new_inputs['attn_metadata']
+ cross_attn_metadata = new_inputs['cross_attn_metadata']
+ block_offsets = attn_metadata.block_offsets
+ batch_size, _ = block_offsets.size()
+
+ kv_seqlens = cross_attn_metadata.kv_seqlens
+ if kv_seqlens.data_ptr() != input_buffers['cross_kv_seqlens'].data_ptr(
+ ):
+ input_buffers['cross_kv_seqlens'].zero_()
+ input_buffers['cross_kv_seqlens'][:batch_size] = kv_seqlens
+
+ new_batch_size = next_power_of_2(batch_size)
+ cross_attn_metadata.block_offsets = input_buffers[
+ 'block_offsets'][:new_batch_size]
+ cross_attn_metadata.q_start_loc = input_buffers[
+ 'q_start_loc'][:new_batch_size]
+ cross_attn_metadata.q_seqlens = input_buffers[
+ 'q_seqlens'][:new_batch_size]
+ cross_attn_metadata.kv_seqlens = input_buffers[
+ 'cross_kv_seqlens'][:new_batch_size]
+
+ new_inputs['cross_attn_metadata'] = cross_attn_metadata
+ return new_inputs
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ model_metas = context.model_metas
+ if model_metas is None:
+ batch_size = context.q_seqlens.size(0)
+ model_metas = [dict(cross_kv_len=0) for _ in range(batch_size)]
+
+ if context.is_decoding:
+ return model_metas
+
+ vision_inputs = context.vision_inputs
+ if vision_inputs is None:
+ return model_metas
+
+ input_mms = vision_inputs.input_multimodals
+ if input_mms is None:
+ return model_metas
+
+ config = self.config.vision_config
+ image_size = config.image_size
+ patch_size = config.patch_size
+ wh = image_size // patch_size
+ img_kv_len = wh * wh + 1
+ img_kv_len = img_kv_len * 4
+
+ new_model_metas = []
+ for idx, input_mm in enumerate(input_mms):
+            if input_mm is None:
+                new_model_metas.append(model_metas[idx])
+                continue
+            images = input_mm['image']
+ num_img = len(images)
+
+ cross_kv_len = 0
+ if model_metas[idx] is not None:
+ cross_kv_len = model_metas[idx].get('cross_kv_len',
+ cross_kv_len)
+ cross_kv_len += img_kv_len * num_img
+ new_model_metas.append(dict(cross_kv_len=cross_kv_len))
+
+        return new_model_metas
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class MLlamaInputProcessor(BaseModelInputProcessor):
+ """mllama input processor."""
+
+ def __init__(self, config: LlamaConfig, dtype: torch.dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ vision_config = self.config.vision_config
+ image_size = vision_config.image_size
+ patch_size = vision_config.patch_size
+ wh = image_size // patch_size
+ encoder_len = wh * wh + 1
+ encoder_len = encoder_len * 4
+ self.encoder_len = encoder_len
+
+ def preprocess_input(self, input_ids, input_multimodals, **kwargs):
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values']
+ aspect_ratio_ids = input_mm['aspect_ratio_ids']
+ aspect_ratio_mask = input_mm['aspect_ratio_mask']
+ offset = input_mm['offset']
+
+ if pixel_values.dtype != self.dtype:
+ pixel_values = pixel_values.to(self.dtype)
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + 1,
+ encoder_len=self.encoder_len,
+ meta=dict(aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
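
`MLlamaInputProcessor` reserves `encoder_len` cross-attention KV slots per image: `(image_size // patch_size) ** 2 + 1` tokens per tile (patches plus the class token), multiplied by the hard-coded four tiles; `update_model_metas` grows `cross_kv_len` by the same amount per image. A hedged arithmetic sketch assuming the common Llama-3.2-Vision values (image_size 560, patch_size 14, max_num_tiles 4); other checkpoints may differ:

```python
image_size, patch_size, max_num_tiles = 560, 14, 4  # assumed config values

wh = image_size // patch_size      # 40 patches per side
tokens_per_tile = wh * wh + 1      # 1601, including the class token
encoder_len = tokens_per_tile * max_num_tiles
assert encoder_len == 6404         # cross-attention KV slots reserved per image
```
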
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index af1b23cfc0..b47ff77b3a 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -82,17 +82,19 @@
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'
})
+# deepseek-v3
+MODULE_MAP.update({
+ 'DeepseekV3ForCausalLM':
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'
+})
+
# llava
MODULE_MAP.update(
{
- 'LlavaLlamaForCausalLM':
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM',
- 'LlavaMistralForCausalLM':
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM',
'LlavaForConditionalGeneration':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501
'LlavaNextForConditionalGeneration': # noqa: E501
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration'
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaNextForConditionalGeneration' # noqa: E501
})
# qwen
@@ -158,7 +160,7 @@
# phi3 vision
MODULE_MAP.update({
'Phi3VForCausalLM':
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3VForCausalLM',
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_v.Phi3VForCausalLM',
})
# phi-3.5-moe
diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py
index 9da1b9f4ea..9604b19af5 100644
--- a/lmdeploy/pytorch/models/patch.py
+++ b/lmdeploy/pytorch/models/patch.py
@@ -8,6 +8,7 @@
import torch
from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import load_state_dict
from lmdeploy.utils import get_logger
@@ -250,6 +251,10 @@ def add_adapters(model: torch.nn.Module,
ranks, scalings = get_ranks_and_scalings(target_name,
adapter_cfgs,
device=device)
+        # split in case target_name contains '.' (e.g. 'attention.wo'),
+        # which cannot be used as a module name and does not match the
+        # keys in model.packed_modules_mapping
+ target_name = target_name.split('.')[-1]
found_mods, pack_idx = find_all_target(model, target_name)
sum_rank = ranks.sum().item()
@@ -295,7 +300,9 @@ def add_adapters(model: torch.nn.Module,
for name, path in adapters.items():
adapter_id = adapter_id_map[name]
checkpoint_path = f'{path}/adapter_model.bin'
- state_dict = torch.load(checkpoint_path, map_location=device)
+ if not osp.exists(checkpoint_path):
+ checkpoint_path = f'{path}/adapter_model.safetensors'
+ state_dict = load_state_dict(checkpoint_path, map_location=device)
if hasattr(model, 'load_lora_weights'):
model.load_lora_weights(state_dict.items(), adapter_id=adapter_id)
diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py
index f9477fdab8..988fee11e5 100644
--- a/lmdeploy/pytorch/models/phi3.py
+++ b/lmdeploy/pytorch/models/phi3.py
@@ -226,7 +226,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -243,7 +242,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -435,7 +433,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
-
-
-class Phi3VForCausalLM(Phi3ForCausalLM):
- ...
diff --git a/lmdeploy/pytorch/models/phi3_moe.py b/lmdeploy/pytorch/models/phi3_moe.py
index 080f5e996c..7d0572513a 100644
--- a/lmdeploy/pytorch/models/phi3_moe.py
+++ b/lmdeploy/pytorch/models/phi3_moe.py
@@ -7,7 +7,7 @@
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RopeType
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
-from lmdeploy.pytorch.nn.moe import FusedMoE
+from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters,
build_rotary_embedding)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
@@ -180,7 +180,7 @@ def __init__(self,
is_tp=False,
)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -448,12 +448,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.num_local_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w1.weight', exp_id, 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w3.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.w2.weight', exp_id, 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id,
+ 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id,
+ 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id,
+ 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
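
The mapping tuples above drop the `_weights`/`.weight` suffixes, so matching is now done on parameter prefixes (`.experts.gate_up`, `.experts.down`) rather than full tensor names. A small, hypothetical loader loop showing how such `(param_name, weight_name, expert_id, shard_id)` tuples are typically consumed; the real consumption happens further down in `load_weights`, which this hunk does not show:

```python
# Hypothetical consumption of an expert_params_mapping: map a checkpoint key
# to the fused expert parameter plus (expert_id, shard_id).
def match_expert_weight(name, expert_params_mapping):
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        fused_name = name.replace(weight_name, param_name)
        return fused_name, expert_id, shard_id
    return None


mapping = [
    ('.experts.gate_up', '.experts.0.w1', 0, 'gate'),
    ('.experts.gate_up', '.experts.0.w3', 0, 'up'),
    ('.experts.down', '.experts.0.w2', 0, 'down'),
]
key = 'model.layers.0.block_sparse_moe.experts.0.w1.weight'
assert match_expert_weight(key, mapping) == (
    'model.layers.0.block_sparse_moe.experts.gate_up.weight', 0, 'gate')
```
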
diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py
new file mode 100644
index 0000000000..c4bf72c767
--- /dev/null
+++ b/lmdeploy/pytorch/models/phi3_v.py
@@ -0,0 +1,476 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
+
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
+from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn.linear import build_rowwise_linear
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+
+from .phi3 import Phi3ForCausalLM, Phi3Model
+from .utils.model import DeployModelMixin
+
+CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0,
+ dropout=0.0,
+ hidden_act='quick_gelu',
+ hidden_size=1024,
+ image_size=336,
+ initializer_factor=1.0,
+ initializer_range=0.02,
+ intermediate_size=4096,
+ layer_norm_eps=1e-05,
+ num_attention_heads=16,
+ num_channels=3,
+ num_hidden_layers=24,
+ patch_size=14,
+ projection_dim=768)
+
+
+class Phi3ImageEmbedding(nn.Module):
+ """image embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ wte=None,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ **kwargs):
+ super().__init__()
+ self.config = config
+ hidden_size = config.n_embd if hasattr(
+ config, 'n_embd') else config.hidden_size
+
+ self.wte = wte
+
+ if (isinstance(config.img_processor, dict) and
+ config.img_processor.get('name', None) == 'clip_vision_model'):
+ assert 'model_name' in config.img_processor, (
+ 'model_name must be provided for CLIPVisionModel')
+ assert 'image_dim_out' in config.img_processor, (
+ 'image_dim_out must be provided for CLIPVisionModel')
+ assert 'num_img_tokens' in config.img_processor, (
+ 'num_img_tokens must be provided for CLIPVisionModel')
+ assert config.img_processor[
+ 'model_name'] == 'openai/clip-vit-large-patch14-336'
+ clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
+ self.img_processor = CLIPVisionModel(clip_config).to(device).to(
+ dtype)
+ image_dim_out = config.img_processor['image_dim_out']
+ self.num_img_tokens = config.img_processor['num_img_tokens']
+ else:
+ raise NotImplementedError(
+ f'img_processor = {config.img_processor}, not implemented')
+
+ self.image_dim_out = image_dim_out
+ self.img_sizes = None
+
+ self.use_hd_transform = kwargs.get('use_hd_transform', False)
+ self.with_learnable_separator = kwargs.get('with_learnable_separator',
+ False)
+ self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub')
+        # use_hd_transform and with_learnable_separator should have same value
+ assert (self.use_hd_transform == self.with_learnable_separator), (
+ 'use_hd_transform and with_learnable_separator '
+ 'should have same value')
+ if self.with_learnable_separator:
+ assert self.use_hd_transform, (
+ 'learnable separator is only for hd transform')
+ # 1024 * 4, merge spatial to channel dimension
+ self.glb_GN = nn.Parameter(
+ torch.empty([1, 1, self.image_dim_out * 4],
+ dtype=dtype,
+ device=device))
+ self.sub_GN = nn.Parameter(
+ torch.empty([1, 1, 1, self.image_dim_out * 4],
+ dtype=dtype,
+ device=device))
+
+ projection_cls = kwargs.get('projection_cls', 'linear')
+ if projection_cls == 'linear':
+ self.img_projection = nn.Linear(image_dim_out,
+ hidden_size,
+ dtype=dtype,
+ device=device)
+ elif projection_cls == 'mlp' and self.use_hd_transform:
+ dim_projection = hidden_size
+ depth = 2
+ layers = [
+ nn.Linear(image_dim_out * 4,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ]
+ for _ in range(1, depth):
+ layers.extend([
+ nn.GELU(),
+ nn.Linear(dim_projection,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ])
+ self.img_projection = nn.Sequential(*layers)
+ elif projection_cls == 'mlp':
+ dim_projection = hidden_size
+ depth = 2
+ layers = [
+ nn.Linear(image_dim_out,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ]
+ for _ in range(1, depth):
+ layers.extend([
+ nn.GELU(),
+ nn.Linear(dim_projection,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ])
+ self.img_projection = nn.Sequential(*layers)
+ else:
+ raise NotImplementedError(
+ f'projection_cls = {projection_cls}, not implemented')
+
+ self.vocab_size = config.vocab_size
+ self.img_features = None
+
+ if isinstance(config.img_processor, dict):
+ self.layer_idx = config.img_processor.get('layer_idx', -2)
+ self.type_feature = config.img_processor.get(
+ 'type_feature', 'patch')
+ else:
+ self.layer_idx = -2
+ self.type_feature = 'patch'
+
+ def get_img_features(self,
+ img_embeds: torch.FloatTensor) -> torch.FloatTensor:
+ LAYER_IDX = self.layer_idx
+ TYPE_FEATURE = self.type_feature
+
+ img_processor_output = self.img_processor(img_embeds,
+ output_hidden_states=True)
+ img_feature = img_processor_output.hidden_states[LAYER_IDX]
+
+ if TYPE_FEATURE == 'patch':
+ patch_feature = img_feature[:, 1:]
+ return patch_feature
+
+ if TYPE_FEATURE == 'cls_patch':
+ return img_feature
+
+ raise NotImplementedError
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ pixel_values: torch.FloatTensor,
+ image_sizes=None,
+ image_mask: torch.Tensor = None,
+ ) -> torch.FloatTensor:
+ """forward."""
+
+ target_device = pixel_values.device
+ target_dtype = pixel_values.dtype
+
+ img_embeds = pixel_values
+ img_sizes = image_sizes
+ img_sizes = img_sizes.cpu()
+
+ if self.use_hd_transform and img_sizes is not None and len(img_sizes):
+ assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501
+ # img_embeds: (num_images, max_num_crops, 3, H, W)
+ # img_sizes: (num_images, 2).view(1, -1)
+
+ bs = img_embeds.shape[0]
+ # Nx(HW)xC
+ img_features = self.get_img_features(img_embeds.flatten(0, 1))
+ base_feat_height = base_feat_width = int(
+ img_features.shape[1]**0.5)
+
+ assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501
+
+ # bs x max_num_crops x (24x24) x C
+ img_features = img_features.view(
+ bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
+ C = self.image_dim_out
+ H = base_feat_height
+
+ output_imgs = []
+ output_len = []
+ # training is tensor, inference is list
+ if isinstance(img_sizes, torch.Tensor):
+ img_sizes = img_sizes.view(-1, 2)
+ for _bs in range(bs):
+ h, w = img_sizes[_bs]
+ h = h // 336
+ w = w // 336
+ B_ = h * w
+
+ # 1 x (24x24) x 1024
+ global_img_feature = img_features[_bs, :1]
+
+ # 1 x 12 x 12 x 4096
+ glb_img = global_img_feature.reshape(
+ 1, H // 2, 2, H // 2, 2,
+ C).permute(0, 1, 3, 2, 4,
+ 5).reshape(1, H // 2, H // 2, 4 * C)
+ temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
+
+ # 1 x 156 x 4096
+ glb_img = torch.cat([glb_img, temp_glb_GN],
+ dim=2).reshape(1, -1, 4 * C)
+
+ # (max_num_crops-1) x (12x12) x C
+ sub_img = img_features[_bs, 1:]
+ # 16x574x1024
+ # get rid of padding sub_img
+ sub_img = sub_img[:B_]
+
+ # (num_crops, 12, 2, 12, 2, 1024)
+ # ->(num_crops, 12, 12, 2, 2, 1024)
+ # -> (num_crops, 12*12, 4*1024)
+ sub_img = (sub_img.reshape(B_, H // 2, 2, H // 2, 2,
+ C).permute(0, 1, 3, 2, 4, 5))
+ sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(
+ 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C)
+ temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
+ sub_img = torch.cat([sub_img, temp_sub_GN],
+ dim=2).reshape(1, -1, 4 * C)
+ # (1, num_img_tokens, 1024*4)
+
+ # glb + sub
+ if self.hd_transform_order == 'glb_sub':
+ output_imgs.append(
+ torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
+ elif self.hd_transform_order == 'sub_glb':
+ output_imgs.append(
+ torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
+ else:
+ raise NotImplementedError(
+ f'hd_transform_order = {self.hd_transform_order}'
+ ) # noqa E501
+
+ temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
+ assert temp_len == output_imgs[-1].shape[
+ 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501
+ output_len.append(temp_len)
+
+ img_set_tensor = []
+ for _output_img in output_imgs:
+ img_feature_proj = self.img_projection(
+ _output_img.to(target_device).to(target_dtype))
+ img_feature_proj = img_feature_proj.flatten(0, 1)
+ img_set_tensor.append(img_feature_proj)
+ img_set_tensor = torch.cat(img_set_tensor)[None]
+ elif img_embeds.ndim == 4:
+ tt = (self.get_img_features(img_embeds).to(target_device).to(
+ target_dtype).reshape(-1, self.image_dim_out))
+ img_set_tensor = self.img_projection(
+ tt) # adapted visual features.
+ elif img_embeds.ndim == 3:
+ tt = (img_embeds.to(target_device).to(target_dtype).view(
+ -1, self.image_dim_out))
+ img_set_tensor = self.img_projection(
+ tt) # adapted visual features.
+ else:
+ raise NotImplementedError
+
+ hidden_states = self.wte(input_ids)
+
+ hidden_states.masked_scatter_(image_mask[..., None], img_set_tensor)
+
+ return hidden_states
+
+
+class Phi3VModel(Phi3Model):
+ """phi3v model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__(config=config, dtype=dtype, device=device)
+
+ self.vision_embed_tokens = None
+ if isinstance(config.embd_layer, dict):
+ # vision embedding layer
+ embedding_config = {
+ 'embedding_cls': config.embd_layer['embedding_cls'],
+ **config.embd_layer
+ }
+ self.vision_embed_tokens = Phi3ImageEmbedding(
+ config,
+ wte=self.embed_tokens,
+ dtype=dtype,
+ device=device,
+ **embedding_config)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ attn_metadata: Any = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ image_mask: torch.Tensor = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ """Rewrite of LlamaModel.forward."""
+
+ if inputs_embeds is None and pixel_values is not None:
+ inputs_embeds = self.vision_embed_tokens(
+ input_ids,
+ pixel_values,
+ image_sizes,
+ image_mask,
+ )
+
+ return super().forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ )
+
+
+class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin):
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__(config, ctx_mgr, dtype=dtype, device=device)
+ self.config = config
+ self.ctx_mgr = ctx_mgr
+ # build model
+ self.model = Phi3VModel(config, dtype=dtype, device=device)
+ # build lm_head
+ self.lm_head = build_rowwise_linear(config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ self.input_processor = Phi3VInputProcessor(config, dtype)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_sizes: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
+ inputs_embeds: torch.Tensor = None,
+ **kwargs,
+ ):
+ """forward."""
+ hidden_states = self.model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ image_mask=image_mask,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: torch.Tensor = None,
+ context: StepContext = None,
+ ):
+ """prepare input."""
+ output = super().prepare_inputs_for_generation(
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ context=context)
+
+ # vision inputs
+ pixel_values = None
+ if context.input_multimodals is not None:
+ input_mms = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ input_mms = [data for im_data in input_mms for data in im_data]
+ if len(input_mms) > 0:
+ pixel_values = torch.cat([data.data for data in input_mms])
+ image_sizes = torch.cat(
+ [data.meta['image_sizes'] for data in input_mms])
+ image_token_id = input_mms[0].meta['image_token_id']
+ image_mask = output['input_ids'] == image_token_id
+ output['pixel_values'] = pixel_values
+ output['image_sizes'] = image_sizes
+ output['image_mask'] = image_mask
+
+ return output
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """load weights."""
+ super().load_weights(weights)
+
+ vis_prefix = 'vision_embed_tokens.'
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if not (vis_prefix in name):
+ continue
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class Phi3VInputProcessor(BaseModelInputProcessor):
+ """Phi3V input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ image_sizes = input_mm['image_sizes']
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(
+ image_sizes=image_sizes,
+ image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
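
The HD transform above asserts that each image expands to `(h*w + 1) * 144 + 1 + (h + 1) * 12` tokens, where `h` and `w` are the image height and width in 336-pixel crops: the global crop contributes 12x12 merged tokens plus one separator per row, each sub-crop row contributes `w*12` tokens plus a separator, and one `glb_GN` token separates the two groups. A pure-arithmetic check of that count:

```python
# Worked check of the HD-transform token count asserted in
# Phi3ImageEmbedding.forward (pure arithmetic, no model code involved).
def num_hd_tokens(height_px: int, width_px: int) -> int:
    h, w = height_px // 336, width_px // 336
    glb = 12 * (12 + 1)            # global crop: 12 rows of 12 + separator
    sub = (h * 12) * (w * 12 + 1)  # sub-crops: h*12 rows of w*12 + separator
    return glb + 1 + sub           # +1 for the glb_GN token between groups


# A 672x1008 image gives h=2, w=3 crops:
assert num_hd_tokens(672, 1008) == (2 * 3 + 1) * 144 + 1 + (2 + 1) * 12 == 1045
```
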
diff --git a/lmdeploy/pytorch/models/q_modules.py b/lmdeploy/pytorch/models/q_modules.py
index 001fab7a60..8379bb18c9 100644
--- a/lmdeploy/pytorch/models/q_modules.py
+++ b/lmdeploy/pytorch/models/q_modules.py
@@ -34,13 +34,17 @@ class QRMSNorm(nn.Module):
"""It performs traditional RMS normalization and then quantizes the output
to 8-bit integers."""
- def __init__(self, hidden_size, eps=1e-6):
+ def __init__(self, hidden_size, eps=1e-6, quant_dtype=torch.int8):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
+ self.quant_dtype = quant_dtype
@classmethod
- def from_float(cls, mod: nn.Module, initialization: bool = True):
+ def from_float(cls,
+ mod: nn.Module,
+ initialization: bool = True,
+ quant_dtype=torch.int8):
"""Class method to create a QRMSNorm instance from a floating-point
module.
@@ -49,7 +53,7 @@ def from_float(cls, mod: nn.Module, initialization: bool = True):
"""
hidden_size = mod.weight.shape[0]
eps = mod.variance_epsilon
- q_mod = cls(hidden_size, eps)
+ q_mod = cls(hidden_size, eps, quant_dtype=quant_dtype)
if initialization:
q_mod.weight = nn.Parameter(mod.weight.detach())
return q_mod
@@ -62,7 +66,10 @@ def forward(self, hidden_states):
with its scale factor.
"""
hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
- hidden_states, self.weight, self.variance_epsilon)
+ hidden_states,
+ self.weight,
+ self.variance_epsilon,
+ quant_dtype=self.quant_dtype)
return QTensor(hidden_states_quant, rms_scale)
@@ -83,16 +90,18 @@ def __init__(self,
out_features: int,
bias: bool = True,
device=None,
- dtype=None) -> None:
+ dtype=None,
+ quant_dtype=torch.int8) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
+ self.quant_dtype = quant_dtype
self.register_buffer(
'weight',
torch.empty((out_features, in_features),
device=device,
- dtype=torch.int8))
+ dtype=quant_dtype))
self.register_buffer(
'scale',
torch.empty((out_features, 1), device=device, dtype=torch.float32))
@@ -103,7 +112,10 @@ def __init__(self,
self.register_parameter('bias', None)
@classmethod
- def from_float(cls, mod: nn.Module, initialization: bool = True):
+ def from_float(cls,
+ mod: nn.Module,
+ initialization: bool = True,
+ quant_dtype=torch.int8):
"""Class method to create a QLinear instance from a floating-point
module.
@@ -114,11 +126,12 @@ def from_float(cls, mod: nn.Module, initialization: bool = True):
mod.out_features,
mod.bias is not None,
device=mod.weight.device,
- dtype=mod.weight.dtype)
+ dtype=mod.weight.dtype,
+ quant_dtype=quant_dtype)
if initialization:
- weight_quant, scale = per_channel_quant(mod.weight.detach(), 8,
- torch.int8)
+ weight_quant, scale = per_channel_quant(mod.weight.detach(),
+ quant_dtype)
q_mod.weight.data = weight_quant
q_mod.scale = scale
@@ -137,7 +150,8 @@ def forward(self, input):
"""
if isinstance(input, torch.Tensor):
- input_quant, input_scale = per_token_quant_int8(input, 1e-7)
+ input_quant, input_scale = per_token_quant_int8(
+ input, 1e-7, quant_dtype=self.quant_dtype)
else:
assert isinstance(input, QTensor)
input_quant, input_scale = input.tensor, input.scale
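
`QRMSNorm` and `QLinear` now take a `quant_dtype`, so the same w8a8 path can target either `torch.int8` or an fp8 dtype instead of hard-coding int8. Below is a standalone reference sketch of per-output-channel symmetric quantization with a configurable dtype; the `per_channel_quant` name is borrowed from the call above, but this is not lmdeploy's kernel:

```python
# Reference per-channel quantization with a configurable target dtype
# (illustration only; lmdeploy uses its own kernels for this).
import torch


def per_channel_quant(weight: torch.Tensor, quant_dtype=torch.int8):
    """Symmetric per-output-channel quantization of a 2D weight."""
    if quant_dtype == torch.int8:
        qmax = 127.0
    else:  # e.g. torch.float8_e4m3fn (requires a recent torch)
        qmax = torch.finfo(quant_dtype).max
    absmax = weight.abs().amax(dim=1, keepdim=True).clamp_min(1e-7)
    scale = absmax / qmax
    q = weight / scale
    if quant_dtype == torch.int8:
        q = q.round().clamp(-128, 127)
    return q.to(quant_dtype), scale.to(torch.float32)


w = torch.randn(16, 32)
q_w, scale = per_channel_quant(w, torch.int8)
dequant = q_w.float() * scale          # approximate reconstruction of w
```
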
diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py
index bf856461a3..20e184bdf8 100644
--- a/lmdeploy/pytorch/models/qwen.py
+++ b/lmdeploy/pytorch/models/qwen.py
@@ -229,7 +229,6 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = getattr(config, 'quantization_config', None)
self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(self.vocab_size,
@@ -263,7 +262,6 @@ def __init__(self,
self.ln_f = RMSNorm(self.embed_dim,
eps=config.layer_norm_epsilon,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py
index 82be75e167..a26aa22d5a 100644
--- a/lmdeploy/pytorch/models/qwen2.py
+++ b/lmdeploy/pytorch/models/qwen2.py
@@ -29,7 +29,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -40,7 +41,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
@@ -224,7 +225,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -241,7 +241,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
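
`Qwen2Attention` now reads `num_replicate_key_value_heads` from the config and forwards it to `build_qkv_proj`; together with the `_get_qkv_out_features` change to `lmdeploy/pytorch/nn/linear.py` later in this diff, the packed QKV projection allocates only the real (de-duplicated) KV heads when a checkpoint stores replicated copies. The widths work out as in this small sketch (numbers are hypothetical):

```python
# Packed QKV output widths with replicated KV heads, following the
# _get_qkv_out_features arithmetic changed later in this diff.
def qkv_out_features(num_q_heads, num_kv_heads, head_size,
                     num_replicate_kv_heads=1):
    num_kv_heads_real = num_kv_heads // num_replicate_kv_heads
    return (num_q_heads * head_size,
            num_kv_heads_real * head_size,
            num_kv_heads_real * head_size)


# Hypothetical: 32 query heads, 8 stored KV heads that are 4x replicas of
# 2 real heads, head size 128.
assert qkv_out_features(32, 8, 128, num_replicate_kv_heads=4) == (4096, 256, 256)
```
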
diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py
index 1aff14483a..de990592d5 100644
--- a/lmdeploy/pytorch/models/qwen2_moe.py
+++ b/lmdeploy/pytorch/models/qwen2_moe.py
@@ -13,7 +13,7 @@
SiluAndMul, build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -185,7 +185,7 @@ def __init__(self,
self.softmax_topk = SoftmaxTopK(self.top_k)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -280,12 +280,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -330,7 +328,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -347,7 +344,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -531,14 +527,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.num_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.gate_proj.weight', exp_id,
- 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.up_proj.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.down_proj.weight', exp_id,
- 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj',
+ exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj',
+ exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj',
+ exp_id, 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py
index b10baaa4d5..bfd6e352f1 100644
--- a/lmdeploy/pytorch/models/qwen2_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_vl.py
@@ -1,18 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
- SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention,
+ LayerNorm, RMSNorm, RopeType, SiluAndMul,
+ build_rotary_embedding)
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2
+from .utils.model import DeployModelMixin
def _apply_mrope_selection(hidden_states: torch.Tensor,
@@ -254,7 +260,6 @@ def __init__(self,
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.mrope_section = config.rope_scaling['mrope_section']
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -271,7 +276,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -337,7 +341,337 @@ def get_input_embeddings(self):
return self.embed_tokens
-class Qwen2VLForConditionalGeneration(nn.Module, CudaGraphMixin):
+class PatchEmbed(nn.Module):
+ """Patch Embed."""
+
+ def __init__(self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ dtype: torch.dtype = None,
+ device: torch.device = None) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels,
+ embed_dim,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(-1, self.in_channels,
+ self.temporal_patch_size,
+ self.patch_size, self.patch_size)
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
+ -1, self.embed_dim)
+ return hidden_states
+
+
+class VisionRotaryEmbedding(nn.Module):
+ """vision rotary embedding."""
+
+ def __init__(self,
+ dim: int,
+ theta: float = 10000.0,
+ device: torch.device = None) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta**(
+ torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class VisionAttention(nn.Module):
+ """Vision attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ dim = config.embed_dim
+ num_heads = config.num_heads
+ head_dim = dim // num_heads
+ self.head_dim = head_dim
+
+ # packed qkv
+ self.qkv = build_qkv_proj(
+ dim,
+ num_q_heads=num_heads,
+ num_kv_heads=num_heads,
+ head_size=head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # rotary embedding
+ self.apply_rotary_pos_emb = ApplyRotaryEmb()
+
+ # attention
+ self.attention = FlashAttention(
+ num_heads,
+ head_dim,
+ causal=False,
+ )
+
+ # o_proj
+ self.proj = build_rowwise_linear(dim,
+ dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ # qkv proj
+ qkv_states = self.qkv(hidden_states)
+ # (-1, heads, head_dim)
+ qkv_states = qkv_states.flatten(0, -2)
+ q, k, v = self.qkv.split_qkv(qkv_states)
+
+ cos, sin = rotary_pos_emb
+ q, k = self.apply_rotary_pos_emb(q, k, cos, sin)
+
+ attn_output = self.attention(
+ q,
+ k,
+ v,
+ q_start_loc=cu_seqlens[:-1],
+ q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1],
+ )
+
+ attn_output = attn_output.reshape(seq_length, -1)
+
+ # o proj
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class VisionMlp(nn.Module):
+ """Vision mlp."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+ dim = config.embed_dim
+ hidden_dim = int(config.embed_dim * config.mlp_ratio)
+ quantization_config = getattr(config, 'quantization_config', None)
+ # gate up
+ self.fc1 = build_colwise_linear(
+ dim,
+ hidden_dim,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # silu and mul
+ if config.hidden_act in [
+ 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python'
+ ]:
+ self.act = nn.GELU()
+ else:
+ self.act = ACT2FN[config.hidden_act]
+
+ # down
+ self.fc2 = build_rowwise_linear(hidden_dim,
+ dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x):
+ """forward."""
+ return self.fc2(self.act(self.fc1(x)))
+
+
+class Qwen2VLVisionBlock(nn.Module):
+ """Vision block."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.norm1 = LayerNorm(config.embed_dim,
+ eps=1e-6,
+ dtype=dtype,
+ device=device)
+ self.norm2 = LayerNorm(config.embed_dim,
+ eps=1e-6,
+ dtype=dtype,
+ device=device)
+
+ self.attn = VisionAttention(config, dtype=dtype, device=device)
+
+ self.mlp = VisionMlp(config, dtype=dtype, device=device)
+
+ def forward(self,
+ hidden_states,
+ cu_seqlens,
+ rotary_pos_emb,
+ residual: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ else:
+ hidden_states, residual = self.norm1(hidden_states, residual)
+
+ hidden_states = self.attn(hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb)
+
+ hidden_states, residual = self.norm2(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+class PatchMerger(nn.Module):
+ """PatchMerger."""
+
+ def __init__(self,
+ dim: int,
+ context_dim: int,
+ spatial_merge_size: int = 2,
+ dtype: torch.dtype = None,
+ device: torch.device = None) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ self.ln_q = nn.LayerNorm(context_dim,
+ eps=1e-6,
+ dtype=dtype,
+ device=device)
+ self.mlp = nn.Sequential(
+ nn.Linear(self.hidden_size,
+ self.hidden_size,
+ dtype=dtype,
+ device=device),
+ nn.GELU(),
+ nn.Linear(self.hidden_size, dim, dtype=dtype, device=device),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+ return x
+
+
+class Qwen2VisionTransformerPretrainedModel(nn.Module):
+ """Vision transformer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.spatial_merge_size = config.spatial_merge_size
+
+ self.patch_embed = PatchEmbed(
+ patch_size=config.patch_size,
+ temporal_patch_size=config.temporal_patch_size,
+ in_channels=config.in_channels,
+ embed_dim=config.embed_dim,
+ dtype=dtype,
+ device=device,
+ )
+
+ head_dim = config.embed_dim // config.num_heads
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2,
+ device=device)
+
+ self.blocks = nn.ModuleList([
+ Qwen2VLVisionBlock(config, layer_idx, dtype=dtype, device=device)
+ for layer_idx in range(config.depth)
+ ])
+ self.merger = PatchMerger(dim=config.hidden_size,
+ context_dim=config.embed_dim,
+ spatial_merge_size=config.spatial_merge_size,
+ dtype=dtype,
+ device=device)
+
+ def rot_pos_emb(self, grid_thw):
+ """rotary position embedding."""
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(
+ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ hidden_states = self.patch_embed(hidden_states)
+ cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+
+ residual = None
+ for blk in self.blocks:
+ hidden_states, residual = blk(hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ residual=residual)
+
+ hidden_states = hidden_states + residual
+
+ return self.merger(hidden_states)
+
+
+class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin,
+ CudaGraphMixin):
"""ModelForCausalLM."""
packed_modules_mapping = {
@@ -360,6 +694,16 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+
+ # preprocessor
+ self.input_processor = Qwen2VLInputProcessor(self.config)
+
+ # build vision model
+ self.visual = Qwen2VisionTransformerPretrainedModel(
+ config.vision_config,
+ dtype=dtype,
+ device=device,
+ )
# build model
self.model = Qwen2Model(config, dtype=dtype, device=device)
# build lm_head
@@ -377,9 +721,26 @@ def forward(
attn_metadata: Any = None,
inputs_embeds: torch.Tensor = None,
mrope_position_ids: torch.Tensor = None,
+ pixel_values: torch.Tensor = None,
+ vis_cu_seqlens: torch.Tensor = None,
+ vis_pos_emb: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
**kwargs,
):
"""model forward, return logits."""
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ if pixel_values is not None:
+ dtype = inputs_embeds.dtype
+ pixel_values = pixel_values.to(dtype)
+ vis_pos_emb = (vis_pos_emb[0].to(dtype),
+ vis_pos_emb[1].to(dtype))
+ image_embeds = self.visual(pixel_values,
+ cu_seqlens=vis_cu_seqlens,
+ rotary_pos_emb=vis_pos_emb)
+ inputs_embeds = inputs_embeds.masked_scatter(
+ image_mask[..., None], image_embeds)
+
hidden_states = self.model(
input_ids=input_ids,
position_ids=position_ids,
@@ -416,6 +777,36 @@ def prepare_inputs_for_generation(
position_ids = context.position_ids
attn_metadata = context.attn_metadata
+ pixel_values = None
+ vis_cu_seqlens = None
+ vis_pos_emb = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ image_data = [
+ input_mm['image'] for input_mm in context.input_multimodals
+ ]
+
+ if len(image_data) > 0:
+ # flatten batch
+ image_data = [
+ data for im_data in image_data for data in im_data
+ ]
+ pixel_values = torch.cat([data.data for data in image_data])
+ image_token_id = image_data[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ grid_thw = torch.cat(
+ [data.meta['grid_thw'] for data in image_data]).cpu()
+ vis_pos_emb = self.visual.rot_pos_emb(grid_thw)
+ vis_cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2],
+ grid_thw[:, 0]).to(pixel_values.device)
+ vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0,
+ dtype=torch.int32)
+ vis_pos_emb = vis_pos_emb.repeat(1, 2)
+ vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())
+
+ mrope_position_ids = getattr(context, 'mrope_position_ids', None)
+
# process vision embeddings
vision_embeddings = context.input_embeddings
vision_embedding_indexing = context.input_embedding_indexing
@@ -433,7 +824,11 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
- mrope_position_ids=context.mrope_position_ids,
+ mrope_position_ids=mrope_position_ids,
+ pixel_values=pixel_values,
+ vis_cu_seqlens=vis_cu_seqlens,
+ vis_pos_emb=vis_pos_emb,
+ image_mask=image_mask,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -450,8 +845,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
- if 'visual' in name:
- continue
if 'rotary_emb.inv_freq' in name:
continue
if ('rotary_emb.cos_cached' in name
@@ -467,8 +860,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
load_weight(param, loaded_weight, shard_id=shard_id)
break
else:
- param = params_dict[name]
- load_weight(param, loaded_weight)
+ if '.qkv.' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
"""make cudagraph buffers from forward inputs."""
@@ -510,3 +910,130 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
'mrope_position_ids']
return new_inputs
+
+ def _update_model_meta_decoding(self, context: StepContext):
+ """update model meta for decoding."""
+ model_metas = context.model_metas
+ position_ids = context.position_ids
+
+ mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
+ mrope_deltas = position_ids.new_tensor(mrope_deltas)
+ mrope_position_ids = position_ids + mrope_deltas[None]
+ mrope_position_ids = mrope_position_ids.expand(3, -1)
+
+ context.mrope_position_ids = mrope_position_ids
+ return model_metas
+
+ def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):
+ """get mrope ids."""
+ t, h, w = grid_thw
+ h //= 2
+ w //= 2
+ stride = torch.tensor([h * w, w, 1], device=device)[:, None]
+ size = torch.tensor([t, h, w], device=device)[:, None]
+ pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
+ pos_ids = pos_ids // stride % size
+ return pos_ids
+
+ def _update_model_meta_prefilling(self, context: StepContext):
+ """update model meta for prefilling."""
+ model_metas = context.model_metas
+ input_multimodals = context.input_multimodals
+ if input_multimodals is None:
+ input_multimodals = [None] * len(model_metas)
+ position_ids = context.position_ids
+ batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())
+ mrope_position_ids = []
+ new_model_metas = []
+ for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas,
+ input_multimodals):
+ images = []
+ if input_mm is not None:
+ images = input_mm['image']
+ if model_meta is None or 'mrope_delta' not in model_meta:
+ mrope_delta = 0
+ else:
+ mrope_delta = model_meta['mrope_delta']
+
+ pos_start = pos_ids[0].item()
+ mrope_pos_ids = pos_ids + mrope_delta
+ mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()
+ for img in images:
+ grid_thw = img.meta['grid_thw'][0].tolist()
+ _, h, w = grid_thw
+ h //= 2
+ w //= 2
+ num_pad = img.end - img.start - max(h, w)
+ mrope_delta -= num_pad
+ fill_start = img.start - pos_start
+ fill_end = img.end - pos_start
+ img_pos_ids = self._get_multimodal_pos_ids(
+ grid_thw, pos_ids.device)
+ img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
+ mrope_pos_ids[:, fill_end:] -= num_pad
+ mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids
+
+ mrope_position_ids.append(mrope_pos_ids)
+ new_model_metas.append(dict(mrope_delta=mrope_delta))
+
+ mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
+ context.mrope_position_ids = mrope_position_ids
+
+ return new_model_metas
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ if context.is_decoding:
+ return self._update_model_meta_decoding(context)
+ else:
+ return self._update_model_meta_prefilling(context)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+InputMultiModalType = List[Dict[str, Any]]
+
+
+class Qwen2VLInputProcessor(BaseModelInputProcessor):
+ """qwen2 input processor."""
+
+ def __init__(self, config: PretrainedConfig) -> None:
+ self.config = config
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values']
+ image_grid_thw = input_mm['image_grid_thw']
+ offset = input_mm['offset']
+ start = offset
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(data=pixel_values,
+ start=start,
+ end=start + num_pad,
+ meta=dict(
+ grid_thw=image_grid_thw,
+ image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
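
Qwen2-VL's M-RoPE assigns each image patch a triple of (temporal, height, width) positions; `_get_multimodal_pos_ids` derives them from the flattened patch index with the stride/size arithmetic shown above, after the 2x2 spatial merge. A standalone check of that arithmetic for a tiny 1x4x6 grid:

```python
# Standalone check of the 3D position ids built by _get_multimodal_pos_ids
# for a grid of t=1, h=4, w=6 patches (2x2 spatial merge -> 1x2x3).
import torch


def mrope_image_pos_ids(t: int, h: int, w: int) -> torch.Tensor:
    h, w = h // 2, w // 2                      # spatial merge
    stride = torch.tensor([h * w, w, 1])[:, None]
    size = torch.tensor([t, h, w])[:, None]
    pos = torch.arange(t * h * w)[None].expand(3, -1)
    return pos // stride % size                # rows: t, h, w indices


assert mrope_image_pos_ids(1, 4, 6).tolist() == [
    [0, 0, 0, 0, 0, 0],                        # temporal index
    [0, 0, 0, 1, 1, 1],                        # row index
    [0, 1, 2, 0, 1, 2],                        # column index
]
```
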
diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py
new file mode 100644
index 0000000000..99bd4c4bfb
--- /dev/null
+++ b/lmdeploy/pytorch/models/utils/model.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+
+from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
+from lmdeploy.pytorch.model_inputs import StepContext
+
+
+class DeployModelMixin:
+
+ def forward(self, *args, **kwargs):
+ """forward of model."""
+ raise NotImplementedError('Not Implemented')
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None,
+ ):
+ """prepare input."""
+ raise NotImplementedError('Not Implemented')
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """load weights."""
+ raise NotImplementedError('Not Implemented')
+
+ def get_logits(self, hidden_states: torch.Tensor):
+ """compute logits of the model output."""
+ return hidden_states
+
+ def update_weights(self):
+ """update weights."""
+ pass
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ return None
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return None
diff --git a/lmdeploy/pytorch/models/utils/multimodal.py b/lmdeploy/pytorch/models/utils/multimodal.py
new file mode 100644
index 0000000000..aebcaf4073
--- /dev/null
+++ b/lmdeploy/pytorch/models/utils/multimodal.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
+
+PreparedInputs = Tuple[List[int], MultiModalInputs]
+
+
+class MultiModalMixin:
+
+ def prepare_multimodal_input(self, input_ids, input_multimodals,
+ **kwargs) -> PreparedInputs:
+ """prepare multimodals inputs."""
+ raise NotImplementedError('prepare input not implemented.')
diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py
new file mode 100644
index 0000000000..c3e8c6a16f
--- /dev/null
+++ b/lmdeploy/pytorch/multimodal/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .data_type import MultiModalData, MultiModalTensor
+
+__all__ = ['MultiModalData', 'MultiModalTensor']
diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py
new file mode 100644
index 0000000000..886c7ffbd0
--- /dev/null
+++ b/lmdeploy/pytorch/multimodal/data_type.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass, fields
+from typing import Any, Dict, List, Union
+
+import torch
+from torch import Tensor
+from torch import distributed as dist
+
+
+class MultiModalData:
+ pass
+
+
+MultiModalDataList = List[MultiModalData]
+
+NestedTensor = Union[Tensor, List[Tensor]]
+
+
+def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
+ """broadcast tensor."""
+ if value.device.type == 'meta':
+ value = torch.empty_like(value, device=device)
+ dist.broadcast(value, src)
+ return value
+
+
+@dataclass
+class MultiModalTensor:
+ data: NestedTensor
+ start: int
+ end: int = None
+ encoder_len: int = None
+ meta: Dict[str, Any] = None
+
+ def __post_init__(self):
+ if self.end is None:
+ self.end = self.start
+
+ def to_device(self, device: str, non_blocking: bool = False):
+ """to device."""
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ if k in ('data', 'meta'):
+ continue
+ v = getattr(self, k)
+ out_dict[k] = v
+
+ if isinstance(self.data, Tensor):
+ data = self.data.to(device=device, non_blocking=non_blocking)
+ else:
+ data = [
+ d.to(device=device, non_blocking=non_blocking)
+ for d in self.data
+ ]
+ out_dict['data'] = data
+
+ new_meta = None
+ if self.meta is not None:
+ new_meta = dict()
+ for k, v in self.meta.items():
+ if isinstance(v, Tensor):
+ v = v.to(device=device, non_blocking=non_blocking)
+ elif hasattr(v, 'to_device'):
+ v = v.to_device(device=device, non_blocking=non_blocking)
+ new_meta[k] = v
+
+ out_dict['meta'] = new_meta
+ return MultiModalTensor(**out_dict)
+
+ def broadcast(self):
+ """broadcast inputs tensors."""
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ if k in ('data', 'meta'):
+ continue
+ v = getattr(self, k)
+ out_dict[k] = v
+
+ if isinstance(self.data, Tensor):
+ data = _broadcast_tensor(self.data)
+ else:
+ data = [_broadcast_tensor(d) for d in self.data]
+ out_dict['data'] = data
+
+ new_meta = None
+ if self.meta is not None:
+ new_meta = dict()
+ for k, v in self.meta.items():
+ if isinstance(v, Tensor):
+ v = _broadcast_tensor(v)
+ self.meta[k] = v
+ elif hasattr(v, 'to_device'):
+ assert hasattr(v, 'broadcast')
+ v = v.broadcast()
+ self.meta[k] = v
+ new_meta[k] = v
+
+ out_dict['meta'] = new_meta
+ return MultiModalTensor(**out_dict)
+
+
+MultiModalInputs = Dict[str, List[MultiModalTensor]]
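
`MultiModalTensor` is the container the new input processors return: `data` holds the raw tensors, `start`/`end` mark the token span the item occupies in `input_ids`, and `meta` carries model-specific extras such as `grid_thw` or `image_sizes`. A usage example; it assumes an lmdeploy build that already contains this module, and the shapes and ids are placeholders:

```python
# Example of wrapping one image's pixel values; shapes and ids are placeholders.
import torch

from lmdeploy.pytorch.multimodal import MultiModalTensor

pixel_values = torch.randn(7, 3, 336, 336)      # image crops (model-specific)
mm = MultiModalTensor(data=pixel_values,
                      start=5,                  # first image-token position
                      end=5 + 1045,             # one past the last image token
                      meta=dict(image_token_id=0))
mm = mm.to_device('cpu')                        # moves data and tensor-valued meta
```
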
diff --git a/lmdeploy/pytorch/multimodal/image_type.py b/lmdeploy/pytorch/multimodal/image_type.py
new file mode 100644
index 0000000000..19211a381f
--- /dev/null
+++ b/lmdeploy/pytorch/multimodal/image_type.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass
+from typing import Any, ClassVar, Dict
+
+from PIL import Image
+
+from .data_type import MultiModalData
+
+
+@dataclass
+class ImageData(MultiModalData):
+ data: Image
+ loc: int
+ meta: Dict[str, Any] = None
+ type: ClassVar[str] = 'image'
diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py
index 63df9a5ae9..4705115bf4 100644
--- a/lmdeploy/pytorch/nn/__init__.py
+++ b/lmdeploy/pytorch/nn/__init__.py
@@ -2,7 +2,7 @@
# attention module is modified from:
# https://github.com/vllm-project/vllm/blob/main/vllm/attention/
from .activation import GeluAndMul, SiluAndMul # noqa: F401
-from .attention import Attention # noqa: F401
+from .attention import Attention, FlashAttention # noqa: F401
from .norm import LayerNorm, RMSNorm # noqa: F401
from .rotary_embedding import ApplyRotaryEmb # noqa: F401
from .rotary_embedding import RopeType # noqa: F401
diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py
index 26f1034d36..684c8122f8 100644
--- a/lmdeploy/pytorch/nn/attention.py
+++ b/lmdeploy/pytorch/nn/attention.py
@@ -9,6 +9,14 @@
from .utils import get_distribute_size
+def _update_num_heads(num_heads: int, num_kv_heads: int):
+ """update heads."""
+ world_size, rank = get_world_rank()
+ num_heads = get_distribute_size(num_heads, world_size, rank)
+ num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
+ return num_heads, num_kv_heads
+
+
class Attention(nn.Module):
"""Attention layer."""
@@ -22,15 +30,19 @@ def __init__(
alibi: bool = False,
sliding_window: int = None,
logit_softcapping: float = None,
- replicate_kv: bool = False,
+ causal: bool = True,
**kwargs,
):
super().__init__()
- num_heads, num_kv_heads = self._update_num_heads(
- num_heads, num_kv_heads, replicate_kv)
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+ if v_head_size is None:
+ v_head_size = head_size
+ num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)
layer_backend = get_backend()
- impl_builder = layer_backend.get_layer_impl_builder(OpType.Attention)
+ impl_builder = layer_backend.get_layer_impl_builder(
+ OpType.PagedAttention)
self.impl = impl_builder.build(
num_heads=num_heads,
@@ -41,18 +53,10 @@ def __init__(
alibi=alibi,
sliding_window=sliding_window,
logit_softcapping=logit_softcapping,
+ causal=causal,
**kwargs,
)
- def _update_num_heads(self, num_heads: int, num_kv_heads: int,
- replicate_kv: bool):
- """update heads."""
- world_size, rank = get_world_rank()
- num_heads = get_distribute_size(num_heads, world_size, rank)
- if not replicate_kv:
- num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
- return num_heads, num_kv_heads
-
def forward(
self,
query: torch.Tensor,
@@ -77,3 +81,75 @@ def forward(
v_scales_zeros=v_scales_zeros,
inplace=inplace,
)
+
+
+class FlashAttention(nn.Module):
+ """flash attention w/o paging."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logit_softcapping: float = None,
+ **kwargs,
+ ):
+ super().__init__()
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+ if v_head_dim is None:
+ v_head_dim = head_dim
+ num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)
+
+ layer_backend = get_backend()
+
+ impl_builder = layer_backend.get_layer_impl_builder(
+ OpType.FlashAttention)
+
+ self.impl = impl_builder.build(
+ num_heads=num_heads,
+ head_dim=head_dim,
+ scale=scale,
+ num_kv_heads=num_kv_heads,
+ v_head_dim=v_head_dim,
+ causal=causal,
+ sliding_window=sliding_window,
+ logit_softcapping=logit_softcapping,
+ **kwargs,
+ )
+
+ def forward(self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ q_start_loc: torch.Tensor,
+ q_seqlens: torch.Tensor,
+ kv_start_loc: torch.Tensor = None,
+ kv_seqlens: torch.Tensor = None,
+ max_q_seqlen: int = None) -> torch.Tensor:
+ """forward."""
+
+ if max_q_seqlen is None:
+ max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+
+ if kv_start_loc is None and kv_seqlens is None:
+ kv_start_loc = q_start_loc
+ kv_seqlens = q_seqlens
+
+ assert kv_start_loc is not None
+ assert kv_seqlens is not None
+
+ return self.impl.forward(
+ query,
+ key,
+ value,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_q_seqlen=max_q_seqlen,
+ )
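
`FlashAttention` is the unpaged counterpart of `Attention`: queries from several variable-length sequences are packed into one tensor and described by start offsets and lengths rather than by a KV cache. The vision blocks above derive those from `cu_seqlens`, as in this pure-torch illustration of the bookkeeping (the attention kernel itself is backend-specific and not shown):

```python
# Bookkeeping that turns per-image patch counts into the q_start_loc/q_seqlens
# arguments of FlashAttention.forward (see VisionAttention.forward above).
import torch

seq_lens = torch.tensor([64, 100, 36])          # patches per image, packed
cu_seqlens = torch.nn.functional.pad(
    seq_lens.cumsum(0, dtype=torch.int32), (1, 0), value=0)
q_start_loc = cu_seqlens[:-1]                   # tensor([  0,  64, 164])
q_seqlens = cu_seqlens[1:] - cu_seqlens[:-1]    # tensor([ 64, 100,  36])
assert int(q_start_loc[-1] + q_seqlens[-1]) == int(seq_lens.sum()) == 200
```
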
diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py
index 08040ee00c..73d0ef918d 100644
--- a/lmdeploy/pytorch/nn/linear.py
+++ b/lmdeploy/pytorch/nn/linear.py
@@ -12,7 +12,7 @@
from ..backends import OpType, get_backend
from ..backends.lora import AdapterInfo
-from .utils import div_up, get_distribute_size
+from .utils import chunk_aligned, div_up, get_distribute_size
logger = get_logger('lmdeploy')
@@ -25,37 +25,30 @@ def _check_qkv_split_layout(layout: str):
f'but get: {layout}')
-def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int):
- """chunk aligned."""
- if align == 1:
- return weight.chunk(chunks, dim=dim)
- size = weight.size(dim)
- assert size % align == 0
- aligned_size = size // align
- align_per_chunk = div_up(aligned_size, chunks)
- sections = [align_per_chunk] * (chunks - 1)
- sections += [aligned_size - align_per_chunk * (chunks - 1)]
- sections = [sec * align for sec in sections]
- return weight.split(sections, dim=dim)
+_chunk_align = chunk_aligned
class QKVMixin:
"""qkv mixin."""
- def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int,
- head_size: int, head_size_v: int):
+ def _get_qkv_out_features(self,
+ num_q_heads: int,
+ num_kv_heads: int,
+ head_size: int,
+ head_size_v: int,
+ num_replicate_kv_heads: int = 1):
"""get io features."""
- all_out_features = (num_q_heads * head_size, num_kv_heads * head_size,
- num_kv_heads * head_size_v)
+ num_kv_heads_real = num_kv_heads // num_replicate_kv_heads
+ all_out_features = (num_q_heads * head_size,
+ num_kv_heads_real * head_size,
+ num_kv_heads_real * head_size_v)
return all_out_features
- def _update_num_heads(self, num_q_heads: int, num_kv_heads: int,
- replicate_kv: bool):
+ def _update_num_heads(self, num_q_heads: int, num_kv_heads: int):
"""update num heads."""
world_size, rank = get_world_rank()
num_q_heads = get_distribute_size(num_q_heads, world_size, rank)
- if not replicate_kv:
- num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
+ num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
return num_q_heads, num_kv_heads
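
The private `_chunk_align` helper is replaced by `chunk_aligned` imported from `.utils`; the removed body above documents the algorithm, reproduced here with a worked example: split a dimension into `chunks` pieces whose sizes stay multiples of `align`, giving the remainder to the last piece.

```python
# The aligned-chunk split that linear.py now imports as chunk_aligned
# (reproduced from the removed _chunk_align body for illustration).
import torch


def chunk_aligned(weight, chunks, dim, align):
    if align == 1:
        return weight.chunk(chunks, dim=dim)
    size = weight.size(dim)
    assert size % align == 0
    aligned_size = size // align
    align_per_chunk = -(-aligned_size // chunks)   # ceil division (div_up)
    sections = [align_per_chunk] * (chunks - 1)
    sections += [aligned_size - align_per_chunk * (chunks - 1)]
    return weight.split([sec * align for sec in sections], dim=dim)


w = torch.arange(40).reshape(10, 4)
parts = chunk_aligned(w, chunks=3, dim=0, align=2)
assert [p.size(0) for p in parts] == [4, 4, 2]    # each a multiple of align=2
```
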
@@ -159,6 +152,239 @@ def weight_loader_B(self, param: nn.Parameter, loaded_weight: torch.Tensor,
param_r.copy_(loaded_weight.t())
+class BlockedF8Linear(nn.Module):
+ """blocked f8 linear."""
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ fp8_dtype: torch.dtype = torch.float8_e4m3fn,
+ colwise: bool = True,
+ is_tp: bool = False,
+ all_reduce: bool = True,
+ ):
+ super().__init__()
+ if device is None:
+ device = torch.device('cpu')
+ if dtype is None:
+ dtype = torch.float16
+ if is_tp:
+ in_features, out_features = self._get_io_features(
+ in_features, out_features, colwise)
+ impl_builder = get_backend().get_layer_impl_builder(
+ OpType.LinearBlockedF8)
+ self.impl = impl_builder.build(in_features,
+ out_features,
+ block_size=128,
+ bias=bias is not None,
+ dtype=dtype)
+ self.block_size = 128
+ self.fp8_dtype = fp8_dtype
+ weight, scale, bias = self.create_weights(in_features, out_features,
+ bias, dtype, device)
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ weight.weight_loader = self.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ scale.weight_loader = self.weight_loader
+ if bias is not None:
+ bias = torch.nn.Parameter(bias, requires_grad=False)
+ bias.weight_loader = self.weight_loader
+ self.register_parameter('weight', weight)
+ self.register_parameter('scale', scale)
+ self.register_parameter('bias', bias)
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.lora_adapters = nn.ModuleDict()
+ self.is_tp = is_tp
+ self.colwise = colwise
+ self.all_reduce = all_reduce
+
+ def _get_io_features(self, in_features: int, out_features: int,
+ colwise: bool):
+ """get io features."""
+ world_size, rank = get_world_rank()
+ if colwise:
+ out_features = get_distribute_size(out_features, world_size, rank)
+ else:
+ in_features = get_distribute_size(in_features, world_size, rank)
+ return in_features, out_features
+
+ def _weight_loader_tp_colwise(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, rank: int,
+ world_size: int):
+ """weight loader for colwise linear."""
+ weight = loaded_weight.chunk(world_size, 0)[rank]
+ return default_weight_loader(param, weight)
+
+ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, rank: int,
+ world_size: int):
+ """weight loader for rowwise linear."""
+ if loaded_weight.dim() == 2:
+ weight = loaded_weight.chunk(world_size, 1)[rank]
+ return default_weight_loader(param, weight)
+ else:
+ # bias
+ if rank != 0:
+ loaded_weight = torch.zeros_like(loaded_weight)
+ return default_weight_loader(param, loaded_weight)
+
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor):
+ """weight loader."""
+ if not self.is_tp:
+ return default_weight_loader(param, loaded_weight)
+
+ world_size, rank = get_world_rank()
+ if self.colwise:
+ return self._weight_loader_tp_colwise(param, loaded_weight, rank,
+ world_size)
+ else:
+ return self._weight_loader_tp_rowwise(param, loaded_weight, rank,
+ world_size)
+
+ def create_weights(self, in_features: int, out_features: int, bias: bool,
+ dtype: torch.dtype, device: torch.device):
+ """create weights."""
+ weight = torch.empty((out_features, in_features),
+ dtype=self.fp8_dtype,
+ device=device)
+ scale = torch.empty(
+ (div_up(out_features,
+ self.block_size), div_up(in_features, self.block_size)),
+ dtype=torch.float32,
+ device=device)
+ if bias:
+ bias = torch.empty((out_features, ), dtype=dtype, device=device)
+ else:
+ bias = None
+ return weight, scale, bias
+
+ def update_weights(self):
+ """update weights."""
+ weight, scale, bias = self.impl.update_weights(self.weight, self.scale,
+ self.bias)
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ self.weight.weight_loader = self.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ self.scale.weight_loader = self.weight_loader
+ if bias is not None:
+ bias = torch.nn.Parameter(bias, requires_grad=False)
+ self.bias.weight_loader = self.weight_loader
+ self.register_parameter('weight', weight)
+ self.register_parameter('scale', scale)
+ self.register_parameter('bias', bias)
+
+ def forward(self, x):
+ """forward of blocked fp8 linear."""
+ all_reduce = False if self.colwise else self.is_tp
+ all_reduce = all_reduce and self.all_reduce
+ if len(self.lora_adapters) == 0:
+ return self.impl.forward(x, self.weight, self.scale, self.bias,
+ all_reduce)
+
+ out = self.impl.forward(x, self.weight, self.scale, self.bias, False)
+ for lora_adapter in self.lora_adapters.values():
+ out = lora_adapter(x, out)
+ if all_reduce:
+ dist.all_reduce(out)
+ return out
+
+
+class MergedBlockedF8Linear(BlockedF8Linear):
+ """merged blocked fp8 linear."""
+
+ def __init__(self,
+ in_features: int,
+ all_out_features: List[int],
+ bias: bool,
+ fp8_dtype: torch.dtype = torch.float8_e4m3fn,
+ replicate: Optional[List[bool]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ is_tp: bool = True,
+ out_names: Optional[List[int]] = None):
+ if replicate is None:
+ replicate = tuple(False for _ in all_out_features)
+ self.block_size = 128
+ self.split_section = all_out_features
+ self.scale_split_section = [
+ section // self.block_size for section in self.split_section
+ ]
+ all_out_features = self._update_all_out_features(
+ all_out_features, replicate)
+ self.all_out_features = all_out_features
+ self.replicate = replicate
+ if out_names is None:
+ out_names = torch.arange(len(self.all_out_features)).tolist()
+ assert len(out_names) == len(self.all_out_features)
+ self.out_names_map = dict(
+ (name, idx) for idx, name in enumerate(out_names))
+ out_features = sum(all_out_features)
+ super().__init__(in_features,
+ out_features,
+ bias,
+ dtype,
+ device,
+ fp8_dtype=fp8_dtype,
+ colwise=True,
+ is_tp=is_tp)
+ self.weight.weight_loader = self.weight_loader
+ self.scale.weight_loader = self.weight_loader
+ self.weight.weight_spliter = self.weight_spliter
+ self.scale.weight_spliter = self.weight_spliter
+ if self.bias is not None:
+ self.bias.weight_loader = self.weight_loader
+ self.bias.weight_spliter = self.weight_spliter
+
+ def _get_io_features(self, in_features: int, out_features: int,
+ colwise: bool):
+ """get io features."""
+ return in_features, out_features
+
+ def _update_all_out_features(self, all_out_features: List[int],
+ replicate: Optional[List[bool]]):
+ """update all out features."""
+ world_size, rank = get_world_rank()
+ new_all_out_features = []
+ for out_feat, rep in zip(all_out_features, replicate):
+ if rep:
+ new_all_out_features.append(out_feat)
+ new_out_feat = get_distribute_size(out_feat, world_size, rank)
+ new_all_out_features.append(new_out_feat)
+ return new_all_out_features
+
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, shard_id: Any):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ shard_idx = self.out_names_map[shard_id]
+ if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
+ all_out_features = [
+ feats // self.block_size for feats in self.all_out_features
+ ]
+ param_w = param.data.split(all_out_features, 0)[shard_idx]
+ else:
+ param_w = param.data.split(self.all_out_features, 0)[shard_idx]
+ if not self.replicate[shard_idx]:
+ loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
+ param_w.copy_(loaded_weight)
+
+ def weight_spliter(self, loaded_weight: torch.Tensor):
+ """weight spliter."""
+ if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
+ return loaded_weight.split(self.scale_split_section, dim=0)
+ return loaded_weight.split(self.split_section, dim=0)
+
+ def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
+ return loaded_weight.split(self.split_section, dim=0)
+
+
class AwqLinear(nn.Module):
"""w4a16 linear."""
@@ -212,7 +438,7 @@ def __init__(
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size
- self.elem_per_int = 32 // self.w_bit
+ self.elem_per_int = 32 // w_bit
self.lora_adapters = nn.ModuleDict()
self.is_tp = is_tp
self.colwise = colwise
@@ -363,12 +589,9 @@ def __init__(self,
w_bit: int,
group_size: int,
bias: bool,
- replicate: Optional[List[bool]] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
out_names: Optional[List[int]] = None):
- if replicate is None:
- replicate = tuple(False for _ in all_out_features)
self.split_section_s = all_out_features
elem_per_int = 32 // w_bit
@@ -377,9 +600,8 @@ def __init__(self,
]
all_out_features = self._update_all_out_features(
- all_out_features, w_bit, group_size, replicate)
+ all_out_features, w_bit, group_size)
self.all_out_features = all_out_features
- self.replicate = replicate
if out_names is None:
out_names = torch.arange(len(self.all_out_features)).tolist()
assert len(out_names) == len(self.all_out_features)
@@ -414,15 +636,12 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int,
return in_features, out_features
def _update_all_out_features(self, all_out_features: List[int], w_bit: int,
- group_size: int,
- replicate: Optional[List[bool]]):
+ group_size: int):
"""update all out features."""
world_size, rank = get_world_rank()
new_all_out_features = []
align = max(32 // w_bit, group_size)
- for out_feat, rep in zip(all_out_features, replicate):
- if rep:
- new_all_out_features.append(out_feat)
+ for out_feat in all_out_features:
new_out_feat = get_distribute_size(out_feat, world_size, rank,
align)
new_all_out_features.append(new_out_feat)
@@ -433,14 +652,11 @@ def weight_loader(self, param: torch.nn.Parameter,
"""weight loader."""
world_size, rank = get_world_rank()
shard_idx = self.out_names_map[shard_id]
-
if loaded_weight.dim() == 1:
# bias
align = max(self.elem_per_int, self.group_size)
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- weight = _chunk_align(loaded_weight, world_size, 0,
- align)[rank]
+ weight = _chunk_align(loaded_weight, world_size, 0, align)[rank]
param_w.copy_(weight)
if param._weight_type in ['scales', 'bias']:
@@ -456,8 +672,7 @@ def weight_loader(self, param: torch.nn.Parameter,
]
param_w = param.data.split(quanted_out_feats, 1)[shard_idx]
- if not self.replicate[shard_idx]:
- weight = _chunk_align(loaded_weight, world_size, -1, align)[rank]
+ weight = _chunk_align(loaded_weight, world_size, -1, align)[rank]
param_w.copy_(weight)
def weight_spliter_wz(self, loaded_weight: torch.Tensor):
@@ -480,45 +695,82 @@ def __init__(self,
head_size_v: int,
w_bit: int,
group_size: int,
- replicate_kv: bool = False,
bias: bool = False,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1):
self.qkv_split_section_s = self._get_qkv_out_features(
- num_q_heads, num_kv_heads, head_size, head_size_v)
+ num_q_heads, num_kv_heads, head_size, head_size_v,
+ num_replicate_kv_heads)
elem_per_int = 32 // w_bit
self.qkv_split_section_wz = [
size // elem_per_int for size in self.qkv_split_section_s
]
num_q_heads, num_kv_heads = self._update_num_heads(
- num_q_heads, num_kv_heads, replicate_kv)
+ num_q_heads, num_kv_heads)
all_out_features = self._get_qkv_out_features(num_q_heads,
num_kv_heads, head_size,
head_size_v)
- replicate = (False, replicate_kv, replicate_kv)
out_names = ('q', 'k', 'v')
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.head_size_v = head_size_v
+ self.num_replicate_kv_heads = num_replicate_kv_heads
+
super().__init__(in_features,
all_out_features,
w_bit=w_bit,
group_size=group_size,
bias=bias,
- replicate=replicate,
device=device,
is_tp=is_tp,
out_names=out_names)
def _update_all_out_features(self, all_out_features: List[int], w_bit: int,
- group_size: int,
- replicate: Optional[List[bool]]):
+ group_size: int):
"""update all out features."""
return all_out_features
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, shard_id: Any):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ chunk_size, chunk_idx = world_size, rank
+ shard_idx = self.out_names_map[shard_id]
+
+ if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']:
+ # update to duplicate k/v for tp_size > num_kv_heads
+ chunk_size = world_size // self.num_replicate_kv_heads
+ chunk_idx = rank // self.num_replicate_kv_heads
+
+ if loaded_weight.dim() == 1:
+ # bias
+ align = max(self.elem_per_int, self.group_size)
+ param_w = param.data.split(self.all_out_features, 0)[shard_idx]
+ weight = _chunk_align(loaded_weight, chunk_size, 0,
+ align)[chunk_idx]
+ param_w.copy_(weight)
+ return
+
+ if param._weight_type in ['scales', 'bias']:
+ # scales
+ align = max(self.elem_per_int, self.group_size)
+ param_w = param.data.split(self.all_out_features, -1)[shard_idx]
+ else:
+ # qweight or qzeros
+ align = max(self.elem_per_int,
+ self.group_size) // self.elem_per_int
+ quanted_out_feats = [
+ feat // self.elem_per_int for feat in self.all_out_features
+ ]
+ param_w = param.data.split(quanted_out_feats, 1)[shard_idx]
+
+ weight = _chunk_align(loaded_weight, chunk_size, -1, align)[chunk_idx]
+ param_w.copy_(weight)
+
def weight_spliter_wz(self,
loaded_weight: torch.Tensor,
layout: str = 'default'):
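In `QKVAwqLinear.weight_loader` the removed `replicate_kv` flag is generalized into `num_replicate_kv_heads`: when tensor parallelism exceeds the number of KV heads, several adjacent ranks share one KV head, so the k/v shards are cut into `world_size // num_replicate_kv_heads` chunks and indexed by `rank // num_replicate_kv_heads`. A small worked example of that arithmetic (the setup is hypothetical, and deriving the value as `world_size // num_kv_heads` is just one plausible convention):

```python
world_size = 8                      # hypothetical 8-way tensor parallel
num_q_heads, num_kv_heads = 32, 2   # GQA: far fewer KV heads than ranks
num_replicate_kv_heads = world_size // num_kv_heads   # 4 ranks share each KV head

for rank in range(world_size):
    kv_chunk_size = world_size // num_replicate_kv_heads  # 2 distinct k/v shards
    kv_chunk_idx = rank // num_replicate_kv_heads          # ranks 0-3 -> shard 0, 4-7 -> shard 1
    q_chunk_idx = rank                                      # q is still split across all ranks
    print(rank, q_chunk_idx, kv_chunk_idx, kv_chunk_size)
```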
@@ -566,17 +818,16 @@ def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
class W8A8Linear(nn.Module):
"""w8a8 linear."""
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- colwise: bool = True,
- is_tp: bool = False,
- all_reduce: bool = True,
- ):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ colwise: bool = True,
+ is_tp: bool = False,
+ all_reduce: bool = True,
+ quant_dtype: Optional[torch.dtype] = torch.int8):
super().__init__()
if device is None:
device = torch.device('cpu')
@@ -586,10 +837,12 @@ def __init__(
in_features, out_features = self._get_io_features(
in_features, out_features, colwise)
impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8)
+ self.quant_dtype = quant_dtype
self.impl = impl_builder.build(in_features,
out_features,
bias is not None,
- dtype=dtype)
+ dtype=dtype,
+ quant_dtype=quant_dtype)
weight, scale, bias = self.create_weights(in_features, out_features,
bias, dtype, device)
weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -631,7 +884,9 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, rank: int,
world_size: int):
"""weight loader for rowwise linear."""
- if loaded_weight.dim() == 2 and param.dtype == torch.int8:
+ if loaded_weight.dim() == 2 and param.dtype in (torch.int8,
+ torch.float8_e4m3fn,
+ torch.float8_e5m2):
weight = loaded_weight.chunk(world_size, 1)[rank]
return default_weight_loader(param, weight)
elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1:
@@ -661,7 +916,7 @@ def create_weights(self, in_features: int, out_features: int, bias: bool,
dtype: torch.dtype, device: torch.device):
"""create weights."""
weight = torch.empty((out_features, in_features),
- dtype=torch.int8,
+ dtype=self.quant_dtype,
device=device)
scale = torch.empty((out_features, 1),
dtype=torch.float32,
@@ -710,18 +965,14 @@ def __init__(self,
in_features: int,
all_out_features: List[int],
bias: bool,
- replicate: Optional[List[bool]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
- out_names: Optional[List[int]] = None):
- if replicate is None:
- replicate = tuple(False for _ in all_out_features)
+ out_names: Optional[List[int]] = None,
+ quant_dtype: torch.dtype = torch.int8):
self.split_section = all_out_features
- all_out_features = self._update_all_out_features(
- all_out_features, replicate)
+ all_out_features = self._update_all_out_features(all_out_features)
self.all_out_features = all_out_features
- self.replicate = replicate
if out_names is None:
out_names = torch.arange(len(self.all_out_features)).tolist()
assert len(out_names) == len(self.all_out_features)
@@ -734,7 +985,8 @@ def __init__(self,
dtype,
device,
colwise=True,
- is_tp=is_tp)
+ is_tp=is_tp,
+ quant_dtype=quant_dtype)
self.weight.weight_loader = self.weight_loader
self.scale.weight_loader = self.weight_loader
self.weight.weight_spliter = self.weight_spliter
@@ -748,14 +1000,11 @@ def _get_io_features(self, in_features: int, out_features: int,
"""get io features."""
return in_features, out_features
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
world_size, rank = get_world_rank()
new_all_out_features = []
- for out_feat, rep in zip(all_out_features, replicate):
- if rep:
- new_all_out_features.append(out_feat)
+ for out_feat in all_out_features:
new_out_feat = get_distribute_size(out_feat, world_size, rank)
new_all_out_features.append(new_out_feat)
return new_all_out_features
@@ -766,8 +1015,7 @@ def weight_loader(self, param: torch.nn.Parameter,
world_size, rank = get_world_rank()
shard_idx = self.out_names_map[shard_id]
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
+ loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
param_w.copy_(loaded_weight)
def weight_spliter(self, loaded_weight: torch.Tensor):
@@ -787,38 +1035,60 @@ def __init__(self,
num_kv_heads: int,
head_size: int,
head_size_v: int,
- replicate_kv: bool = False,
bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1,
+ quant_dtype: torch.dtype = torch.int8):
+
self.qkv_split_section = self._get_qkv_out_features(
- num_q_heads, num_kv_heads, head_size, head_size_v)
+ num_q_heads, num_kv_heads, head_size, head_size_v,
+ num_replicate_kv_heads)
num_q_heads, num_kv_heads = self._update_num_heads(
- num_q_heads, num_kv_heads, replicate_kv)
+ num_q_heads, num_kv_heads)
all_out_features = self._get_qkv_out_features(num_q_heads,
num_kv_heads, head_size,
head_size_v)
- replicate = (False, replicate_kv, replicate_kv)
out_names = ('q', 'k', 'v')
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.head_size_v = head_size_v
+ self.num_replicate_kv_heads = num_replicate_kv_heads
super().__init__(in_features,
all_out_features,
bias=bias,
- replicate=replicate,
dtype=dtype,
device=device,
is_tp=is_tp,
- out_names=out_names)
+ out_names=out_names,
+ quant_dtype=quant_dtype)
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
return all_out_features
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, shard_id: Any):
+ """weight loader."""
+ _, rank = get_world_rank()
+ shard_idx = self.out_names_map[shard_id]
+ param_w = param.data.split(self.all_out_features, 0)[shard_idx]
+ num_head = self.num_q_heads if shard_id == 'q' \
+ else self.num_kv_heads
+ head_dim = self.head_size if shard_id in ['q', 'k'] \
+ else self.head_size_v
+ # update to duplicate k/v for tp_size > num_kv_heads
+ rank_idx = rank if shard_id == 'q' \
+ else rank // self.num_replicate_kv_heads
+ sec_start = rank_idx * num_head * head_dim
+ sec_len = num_head * head_dim
+ loaded_weight = loaded_weight.narrow(dim=0,
+ start=sec_start,
+ length=sec_len)
+ param_w.copy_(loaded_weight)
+
def weight_spliter(self,
loaded_weight: torch.Tensor,
layout: str = 'default'):
@@ -986,18 +1256,13 @@ def __init__(self,
in_features: int,
all_out_features: List[int],
bias: bool,
- replicate: Optional[List[bool]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
out_names: Optional[List[int]] = None):
- if replicate is None:
- replicate = tuple(False for _ in all_out_features)
self.split_section = all_out_features
- all_out_features = self._update_all_out_features(
- all_out_features, replicate)
+ all_out_features = self._update_all_out_features(all_out_features)
self.all_out_features = all_out_features
- self.replicate = replicate
if out_names is None:
out_names = torch.arange(len(self.all_out_features)).tolist()
assert len(out_names) == len(self.all_out_features)
@@ -1022,14 +1287,11 @@ def _get_io_features(self, in_features: int, out_features: int,
"""get io features."""
return in_features, out_features
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
world_size, rank = get_world_rank()
new_all_out_features = []
- for out_feat, rep in zip(all_out_features, replicate):
- if rep:
- new_all_out_features.append(out_feat)
+ for out_feat in all_out_features:
new_out_feat = get_distribute_size(out_feat, world_size, rank)
new_all_out_features.append(new_out_feat)
return new_all_out_features
@@ -1040,8 +1302,7 @@ def weight_loader(self, param: torch.nn.Parameter,
world_size, rank = get_world_rank()
shard_idx = self.out_names_map[shard_id]
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
+ loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
param_w.copy_(loaded_weight)
def weight_spliter(self, loaded_weight: torch.Tensor):
@@ -1061,35 +1322,36 @@ def __init__(self,
num_kv_heads: int,
head_size: int,
head_size_v: int,
- replicate_kv: bool = False,
bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1):
+
self.qkv_split_section = self._get_qkv_out_features(
- num_q_heads, num_kv_heads, head_size, head_size_v)
+ num_q_heads, num_kv_heads, head_size, head_size_v,
+ num_replicate_kv_heads)
num_q_heads, num_kv_heads = self._update_num_heads(
- num_q_heads, num_kv_heads, replicate_kv)
+ num_q_heads, num_kv_heads)
all_out_features = self._get_qkv_out_features(num_q_heads,
num_kv_heads, head_size,
head_size_v)
- replicate = (False, replicate_kv, replicate_kv)
out_names = ('q', 'k', 'v')
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.head_size_v = head_size_v
+ self.num_replicate_kv_heads = num_replicate_kv_heads
+
super().__init__(in_features,
all_out_features,
bias=bias,
- replicate=replicate,
dtype=dtype,
device=device,
is_tp=is_tp,
out_names=out_names)
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
return all_out_features
@@ -1097,15 +1359,20 @@ def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, shard_id: Any):
"""weight loader."""
world_size, rank = get_world_rank()
+ chunk_size, chunk_idx = world_size, rank
shard_idx = self.out_names_map[shard_id]
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- if shard_idx in [0, 1]:
- loaded_weight = _chunk_align(loaded_weight, world_size, 0,
- self.head_size)[rank]
- if shard_idx == 2:
- loaded_weight = _chunk_align(loaded_weight, world_size, 0,
- self.head_size_v)[rank]
+
+ if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']:
+ # update to duplicate k/v for tp_size > num_kv_heads
+ chunk_size = world_size // self.num_replicate_kv_heads
+ chunk_idx = rank // self.num_replicate_kv_heads
+ if shard_idx in [0, 1]:
+ loaded_weight = _chunk_align(loaded_weight, chunk_size, 0,
+ self.head_size)[chunk_idx]
+ elif shard_idx == 2:
+ loaded_weight = _chunk_align(loaded_weight, chunk_size, 0,
+ self.head_size_v)[chunk_idx]
param_w.copy_(loaded_weight)
def weight_spliter(self,
@@ -1161,6 +1428,10 @@ def build_linear(in_features: int,
)
quant_method = quant_config['quant_method']
+ quant_dtype = torch.int8
+ if 'quant_dtype' in quant_config:
+ quant_dtype = eval('torch.' + quant_config['quant_dtype'])
+
if quant_method == 'awq':
w_bit = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
@@ -1176,10 +1447,28 @@ def build_linear(in_features: int,
all_reduce=all_reduce,
)
if quant_method == 'smooth_quant':
- return W8A8Linear(
+ return W8A8Linear(in_features,
+ out_features,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ colwise=colwise,
+ is_tp=is_tp,
+ all_reduce=all_reduce,
+ quant_dtype=quant_dtype)
+ elif quant_method == 'fp8':
+ fmt = quant_config.get('fmt', 'e4m3')
+ if fmt == 'e4m3':
+ fp8_dtype = torch.float8_e4m3fn
+ elif fmt == 'e5m2':
+ fp8_dtype = torch.float8_e5m2
+ else:
+ raise TypeError(f'Unsupported fp8 fmt: {fmt}')
+ return BlockedF8Linear(
in_features,
out_features,
bias=bias,
+ fp8_dtype=fp8_dtype,
dtype=dtype,
device=device,
colwise=colwise,
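`build_linear` now keys off `quant_config['quant_method']` and reads two optional fields: `quant_dtype` for `smooth_quant` and `fmt` (`e4m3`/`e5m2`) for the new `fp8` branch. The field names below come straight from the code above; the dicts themselves are illustrative:

```python
# -> W8A8Linear with the default int8 weights (quant_dtype absent)
smooth_quant_cfg = {'quant_method': 'smooth_quant'}

# -> W8A8Linear with fp8 per-channel weights (quant_dtype is eval'd as a torch dtype name)
smooth_quant_fp8_cfg = {'quant_method': 'smooth_quant', 'quant_dtype': 'float8_e4m3fn'}

# -> BlockedF8Linear; 'fmt' picks float8_e4m3fn (the default) or float8_e5m2
blocked_fp8_cfg = {'quant_method': 'fp8', 'fmt': 'e4m3'}
```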
@@ -1260,6 +1549,10 @@ def build_merged_colwise_linear(
)
quant_method = quant_config['quant_method']
+ quant_dtype = torch.int8
+ if 'quant_dtype' in quant_config:
+ quant_dtype = eval('torch.' + quant_config['quant_dtype'])
+
if quant_method == 'awq':
w_bit = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
@@ -1273,10 +1566,27 @@ def build_merged_colwise_linear(
is_tp=is_tp,
)
if quant_method == 'smooth_quant':
- return MergedW8A8Linear(
+ return MergedW8A8Linear(in_features=in_features,
+ all_out_features=all_out_features,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ is_tp=is_tp,
+ out_names=out_names,
+ quant_dtype=quant_dtype)
+ elif quant_method == 'fp8':
+ fmt = quant_config.get('fmt', 'e4m3')
+ if fmt == 'e4m3':
+ fp8_dtype = torch.float8_e4m3fn
+ elif fmt == 'e5m2':
+ fp8_dtype = torch.float8_e5m2
+ else:
+ raise TypeError(f'Unsupported fp8 fmt: {fmt}')
+ return MergedBlockedF8Linear(
in_features=in_features,
all_out_features=all_out_features,
bias=bias,
+ fp8_dtype=fp8_dtype,
dtype=dtype,
device=device,
is_tp=is_tp,
@@ -1291,12 +1601,12 @@ def build_qkv_proj(in_features: int,
num_kv_heads: int,
head_size: int,
head_size_v: int = None,
- replicate_kv: bool = False,
bias: bool = False,
quant_config: Any = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1):
"""build qkv proj."""
if is_tp:
world_size, _ = get_world_rank()
@@ -1306,48 +1616,47 @@ def build_qkv_proj(in_features: int,
head_size_v = head_size
if quant_config is None:
- return QKVBaseLinear(
- in_features=in_features,
- num_q_heads=num_q_heads,
- num_kv_heads=num_kv_heads,
- head_size=head_size,
- head_size_v=head_size_v,
- replicate_kv=replicate_kv,
- bias=bias,
- dtype=dtype,
- device=device,
- is_tp=is_tp,
- )
+ return QKVBaseLinear(in_features=in_features,
+ num_q_heads=num_q_heads,
+ num_kv_heads=num_kv_heads,
+ head_size=head_size,
+ head_size_v=head_size_v,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ is_tp=is_tp,
+ num_replicate_kv_heads=num_replicate_kv_heads)
quant_method = quant_config['quant_method']
+ quant_dtype = torch.int8
+ if 'quant_dtype' in quant_config:
+ quant_dtype = eval('torch.' + quant_config['quant_dtype'])
+
if quant_method == 'awq':
w_bit = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
- return QKVAwqLinear(
- in_features=in_features,
- num_q_heads=num_q_heads,
- num_kv_heads=num_kv_heads,
- head_size=head_size,
- head_size_v=head_size_v,
- replicate_kv=replicate_kv,
- w_bit=w_bit,
- group_size=group_size,
- bias=bias,
- device=device,
- is_tp=is_tp,
- )
+ return QKVAwqLinear(in_features=in_features,
+ num_q_heads=num_q_heads,
+ num_kv_heads=num_kv_heads,
+ head_size=head_size,
+ head_size_v=head_size_v,
+ w_bit=w_bit,
+ group_size=group_size,
+ bias=bias,
+ device=device,
+ is_tp=is_tp,
+ num_replicate_kv_heads=num_replicate_kv_heads)
if quant_method == 'smooth_quant':
- return QKVW8A8Linear(
- in_features=in_features,
- num_q_heads=num_q_heads,
- num_kv_heads=num_kv_heads,
- head_size=head_size,
- head_size_v=head_size_v,
- replicate_kv=replicate_kv,
- bias=bias,
- dtype=dtype,
- device=device,
- is_tp=is_tp,
- )
+ return QKVW8A8Linear(in_features=in_features,
+ num_q_heads=num_q_heads,
+ num_kv_heads=num_kv_heads,
+ head_size=head_size,
+ head_size_v=head_size_v,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ is_tp=is_tp,
+ num_replicate_kv_heads=num_replicate_kv_heads,
+ quant_dtype=quant_dtype)
else:
raise RuntimeError(f'Unsupported quant method: {quant_method}')
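A hedged call sketch for the updated `build_qkv_proj`, assuming the builder lives in `lmdeploy.pytorch.nn.linear` as in upstream lmdeploy and that the caller derives `num_replicate_kv_heads` from its TP degree; all sizes are hypothetical:

```python
import torch
from lmdeploy.pytorch.nn.linear import build_qkv_proj   # assumed module path

# GQA attention with 32 query heads and 2 KV heads; with 8-way TP a caller
# would pass num_replicate_kv_heads = tp // num_kv_heads so k/v get duplicated
# instead of being split below a single head.
qkv = build_qkv_proj(
    in_features=4096,
    num_q_heads=32,
    num_kv_heads=2,
    head_size=128,
    bias=False,
    quant_config=None,            # -> the QKVBaseLinear branch above
    dtype=torch.float16,
    num_replicate_kv_heads=4,
)
```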
diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py
index 47176335c4..4921825c9a 100644
--- a/lmdeploy/pytorch/nn/moe.py
+++ b/lmdeploy/pytorch/nn/moe.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
+from typing import Any, List, Optional
import torch
import torch.distributed as dist
@@ -8,6 +8,7 @@
from lmdeploy.pytorch.distributed import get_world_rank
from ..backends import OpType, get_backend
+from .utils import div_up
class SoftmaxTopK(nn.Module):
@@ -24,6 +25,102 @@ def forward(self, x: torch.Tensor):
return self.impl.forward(x)
+def create_mlp_weights(hidden_dim: int, ffn_dim: int, num_experts: int,
+ dtype: torch.dtype, device: torch.device):
+ """create weights."""
+ gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim),
+ dtype=dtype,
+ device=device)
+ down_weights = torch.empty((num_experts, hidden_dim, ffn_dim),
+ dtype=dtype,
+ device=device)
+ return gate_up_weights, down_weights
+
+
+def _update_args(hidden_dim: int, ffn_dim: int):
+ """update args."""
+ world_size, _ = get_world_rank()
+ assert ffn_dim % world_size == 0
+ ffn_dim = ffn_dim // world_size
+ return hidden_dim, ffn_dim
+
+
+class LinearWeights(nn.Module):
+ """fused moe linear weights."""
+
+ def __init__(self,
+ num_experts: int,
+ in_features: int,
+ out_features: int,
+ weight_type: str,
+ dtype: torch.dtype,
+ device: torch.device,
+ expert_list: List[int] = None,
+ ep: bool = False):
+ super().__init__()
+ weight = torch.empty((num_experts, out_features, in_features),
+ dtype=dtype,
+ device=device)
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ self.register_parameter('weight', weight)
+ self.ep = ep
+ self.expert_list = expert_list
+ self.weight_type = weight_type
+ self.half_out = out_features // 2
+
+ if self.ep:
+ self.expert_map = dict(
+ (eid, idx) for idx, eid in enumerate(expert_list))
+ self.weight.weight_loader = self.weight_loader_ep
+ else:
+ self.weight.weight_loader = self.weight_loader_tp
+
+ def update_weight(self, weight: torch.Tensor):
+ """update weight."""
+ weight_loader = self.weight.weight_loader
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ weight.weight_loader = weight_loader
+ self.register_parameter('weight', weight)
+
+ def weight_loader_tp(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ if shard_id == 'gate':
+ param_data = param.data[expert_id, :self.half_out]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
+ elif shard_id == 'up':
+ param_data = param.data[expert_id, self.half_out:]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
+ elif shard_id == 'down':
+ param_data = param.data[expert_id]
+ weight = loaded_weight.chunk(world_size, dim=1)[rank]
+ else:
+ raise RuntimeError(f'Unknown shard_id: {shard_id}')
+ param_data.copy_(weight)
+
+ def weight_loader_ep(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader."""
+ expert_list = self.expert_list
+ if expert_id not in expert_list:
+ return
+
+ expert_map = self.expert_map
+ param_id = expert_map[expert_id]
+ if shard_id == 'gate':
+ param_data = param.data[param_id, :self.half_out]
+ elif shard_id == 'up':
+ param_data = param.data[param_id, self.half_out:]
+ elif shard_id == 'down':
+ param_data = param.data[param_id]
+ else:
+ raise RuntimeError(f'Unknown shard_id: {shard_id}')
+ param_data.copy_(loaded_weight)
+
+
class FusedMoE(nn.Module):
"""fused moe."""
@@ -46,42 +143,33 @@ def __init__(self,
impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE)
self.impl = impl_builder.build(top_k, num_experts, renormalize)
- self.expert_list = None
- self.expert_map = None
enable_ep = enable_ep and self.impl.support_ep()
if enable_ep:
world_size, rank = get_world_rank()
expert_list = self.impl.ep_expert_list(world_size, rank)
- self.expert_list = expert_list
- self.expert_map = dict(
- (eid, idx) for idx, eid in enumerate(expert_list))
num_experts = len(expert_list)
- gate_up_weights, down_weights = self.create_weights(hidden_dim,
- ffn_dim,
- num_experts,
- dtype=dtype,
- device=device)
- else:
- hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim)
- gate_up_weights, down_weights = self.create_weights(hidden_dim,
- ffn_dim,
- num_experts,
- dtype=dtype,
- device=device)
- gate_up_weights = torch.nn.Parameter(gate_up_weights,
- requires_grad=False)
- down_weights = torch.nn.Parameter(down_weights, requires_grad=False)
- gate_up_weights._weight_type = 'gate_up_weights'
- down_weights._weight_type = 'down_weights'
- self.register_parameter('gate_up_weights', gate_up_weights)
- self.register_parameter('down_weights', down_weights)
-
- if enable_ep:
- gate_up_weights.weight_loader = self.weight_loader_ep
- down_weights.weight_loader = self.weight_loader_ep
else:
- gate_up_weights.weight_loader = self.weight_loader_tp
- down_weights.weight_loader = self.weight_loader_tp
+ hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim)
+ expert_list = None
+ self.expert_list = expert_list
+ self.gate_up = LinearWeights(num_experts,
+ hidden_dim,
+ ffn_dim * 2,
+ weight_type='gate_up',
+ dtype=dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep)
+ self.down = LinearWeights(
+ num_experts,
+ ffn_dim,
+ hidden_dim,
+ weight_type='down',
+ dtype=dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep,
+ )
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
@@ -93,83 +181,375 @@ def __init__(self,
all_reduce = False
self.all_reduce = all_reduce
- def _update_args(self, hidden_dim: int, ffn_dim: int):
- """update args."""
- world_size, _ = get_world_rank()
- assert ffn_dim % world_size == 0
- ffn_dim = ffn_dim // world_size
- return hidden_dim, ffn_dim
-
- def create_weights(self, hidden_dim: int, ffn_dim: int, num_experts: int,
- dtype: torch.dtype, device: torch.device):
- """create weights."""
- gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim),
- dtype=dtype,
- device=device)
- down_weights = torch.empty((num_experts, hidden_dim, ffn_dim),
- dtype=dtype,
- device=device)
- return gate_up_weights, down_weights
-
def update_weights(self):
"""update weights."""
- gateup_loader = self.gate_up_weights.weight_loader
- down_loader = self.down_weights.weight_loader
gate_up_weights, down_weights = self.impl.update_weights(
- self.gate_up_weights, self.down_weights)
- gate_up_weights = torch.nn.Parameter(gate_up_weights,
- requires_grad=False)
- down_weights = torch.nn.Parameter(down_weights, requires_grad=False)
- gate_up_weights.weight_loader = gateup_loader
- down_weights.weight_loader = down_loader
- gate_up_weights._weight_type = 'gate_up_weights'
- down_weights._weight_type = 'down_weights'
- self.register_parameter('gate_up_weights', gate_up_weights)
- self.register_parameter('down_weights', down_weights)
+ self.gate_up.weight, self.down.weight)
+ self.gate_up.update_weight(gate_up_weights)
+ self.down.update_weight(down_weights)
- def weight_loader_tp(self, param: torch.nn.Parameter,
- loaded_weight: torch.Tensor, expert_id: int,
- shard_id: str):
- """weight loader."""
+ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor):
+ ret = self.impl.forward(hidden_states, topk_weights, topk_ids,
+ self.gate_up.weight, self.down.weight,
+ self.expert_list)
+ if self.all_reduce:
+ dist.all_reduce(ret)
+ return ret
+
+
+class LinearWeightsW8A8(LinearWeights):
+ """fused moe linear w8a8 weights."""
+
+ def __init__(self,
+ num_experts: int,
+ in_features: int,
+ out_features: int,
+ weight_type: str,
+ device: torch.device,
+ expert_list: List[int] = None,
+ ep: bool = False):
+ super().__init__(
+ num_experts=num_experts,
+ in_features=in_features,
+ out_features=out_features,
+ weight_type=weight_type,
+ dtype=torch.int8,
+ device=device,
+ expert_list=expert_list,
+ ep=ep,
+ )
+ scale = torch.empty((num_experts, out_features, 1),
+ dtype=torch.float32,
+ device=device)
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ self.register_parameter('scale', scale)
+
+ if self.ep:
+ self.scale.weight_loader = self.weight_loader_ep
+ else:
+ self.scale.weight_loader = self.weight_loader_scale_tp
+
+ def update_weight(self, weight: torch.Tensor, scale: torch.Tensor):
+ """update weight."""
+ super().update_weight(weight=weight)
+ weight_loader = self.scale.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ scale.weight_loader = weight_loader
+ self.register_parameter('scale', scale)
+
+ def weight_loader_scale_tp(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader scale tp."""
world_size, rank = get_world_rank()
if shard_id == 'gate':
- param_data = param.data[expert_id, :self.ffn_dim]
+ param_data = param.data[expert_id, :self.half_out]
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'up':
- param_data = param.data[expert_id, self.ffn_dim:]
+ param_data = param.data[expert_id, self.half_out:]
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
param_data = param.data[expert_id]
- weight = loaded_weight.chunk(world_size, dim=1)[rank]
+ weight = loaded_weight
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
param_data.copy_(weight)
- def weight_loader_ep(self, param: torch.nn.Parameter,
- loaded_weight: torch.Tensor, expert_id: int,
- shard_id: str):
- """weight loader."""
- expert_list = self.expert_list
- if expert_id not in expert_list:
- return
- expert_map = self.expert_map
- param_id = expert_map[expert_id]
+class FusedMoEW8A8(nn.Module):
+ """fused moe w8a8."""
+
+ def __init__(self,
+ hidden_dim: int,
+ ffn_dim: int,
+ num_experts: int,
+ top_k: int,
+ renormalize: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ all_reduce: bool = True,
+ enable_ep: bool = False):
+ super().__init__()
+ if device is None:
+ device = torch.device('cpu')
+ dtype = torch.float16 if dtype is None else dtype
+
+ impl_builder = get_backend().get_layer_impl_builder(
+ OpType.FusedMoEW8A8)
+ self.impl = impl_builder.build(top_k, num_experts, renormalize, dtype)
+
+ enable_ep = enable_ep and self.impl.support_ep()
+ if enable_ep:
+ world_size, rank = get_world_rank()
+ expert_list = self.impl.ep_expert_list(world_size, rank)
+ num_experts = len(expert_list)
+ else:
+ hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim)
+ expert_list = None
+ self.expert_list = expert_list
+
+ self.gate_up = LinearWeightsW8A8(num_experts,
+ hidden_dim,
+ ffn_dim * 2,
+ weight_type='gate_up',
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep)
+ self.down = LinearWeightsW8A8(
+ num_experts,
+ ffn_dim,
+ hidden_dim,
+ weight_type='down',
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep,
+ )
+
+ self.hidden_dim = hidden_dim
+ self.ffn_dim = ffn_dim
+ self.num_experts = num_experts
+ self.dtype = dtype
+ self.device = device
+ world_size, _ = get_world_rank()
+ if world_size == 1:
+ all_reduce = False
+ self.all_reduce = all_reduce
+
+ def update_weights(self):
+ """update weights."""
+ (gate_up_weights, down_weights, gate_up_scale,
+ down_scale) = self.impl.update_weights(self.gate_up.weight,
+ self.down.weight,
+ self.gate_up.scale,
+ self.down.scale)
+ self.gate_up.update_weight(gate_up_weights, gate_up_scale)
+ self.down.update_weight(down_weights, down_scale)
+
+ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor):
+ ret = self.impl.forward(hidden_states, topk_weights, topk_ids,
+ self.gate_up.weight, self.gate_up.scale,
+ self.down.weight, self.down.scale,
+ self.expert_list)
+ if self.all_reduce:
+ dist.all_reduce(ret)
+ return ret
+
+
+class LinearWeightsBlockedF8(LinearWeights):
+ """fused moe linear blocked fp8 weights."""
+
+ def __init__(self,
+ num_experts: int,
+ in_features: int,
+ out_features: int,
+ weight_type: str,
+ block_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ expert_list: List[int] = None,
+ ep: bool = False):
+ super().__init__(
+ num_experts=num_experts,
+ in_features=in_features,
+ out_features=out_features,
+ weight_type=weight_type,
+ dtype=dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=ep,
+ )
+ self.block_size = block_size
+ scale = torch.empty((num_experts, div_up(
+ out_features, block_size), div_up(in_features, block_size)),
+ dtype=torch.float32,
+ device=device)
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ self.register_parameter('scale', scale)
+
+ if self.ep:
+ self.scale.weight_loader = self.weight_loader_ep
+ else:
+ self.scale.weight_loader = self.weight_loader_scale_tp
+
+ def update_weight(self, weight: torch.Tensor, scale: torch.Tensor):
+ """update weight."""
+ super().update_weight(weight=weight)
+ weight_loader = self.scale.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ scale.weight_loader = weight_loader
+ self.register_parameter('scale', scale)
+
+ def weight_loader_scale_tp(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader scale tp."""
+ world_size, rank = get_world_rank()
+ block_size = self.block_size
+ half_out = self.half_out // block_size
if shard_id == 'gate':
- param_data = param.data[param_id, :self.ffn_dim]
+ param_data = param.data[expert_id, :half_out]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'up':
- param_data = param.data[param_id, self.ffn_dim:]
+ param_data = param.data[expert_id, half_out:]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
- param_data = param.data[param_id]
+ param_data = param.data[expert_id]
+ weight = loaded_weight.chunk(world_size, dim=1)[rank]
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
- param_data.copy_(loaded_weight)
+ param_data.copy_(weight)
+
+
+class FusedMoEBlockedF8(nn.Module):
+ """fused moe blocked f8."""
+
+ def __init__(self,
+ hidden_dim: int,
+ ffn_dim: int,
+ num_experts: int,
+ top_k: int,
+ renormalize: bool = False,
+ fp8_dtype: torch.dtype = torch.float8_e4m3fn,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ all_reduce: bool = True,
+ enable_ep: bool = False):
+ super().__init__()
+ if device is None:
+ device = torch.device('cpu')
+ dtype = torch.float16 if dtype is None else dtype
+ self.block_size = 128
+ impl_builder = get_backend().get_layer_impl_builder(
+ OpType.FusedMoEBlockedF8)
+ self.impl = impl_builder.build(top_k,
+ num_experts,
+ renormalize,
+ block_size=self.block_size,
+ out_dtype=dtype)
+
+ enable_ep = enable_ep and self.impl.support_ep()
+ if enable_ep:
+ world_size, rank = get_world_rank()
+ expert_list = self.impl.ep_expert_list(world_size, rank)
+ num_experts = len(expert_list)
+ else:
+ hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim)
+ expert_list = None
+ self.expert_list = expert_list
+
+ self.gate_up = LinearWeightsBlockedF8(num_experts,
+ hidden_dim,
+ ffn_dim * 2,
+ weight_type='gate_up',
+ block_size=self.block_size,
+ dtype=fp8_dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep)
+ self.down = LinearWeightsBlockedF8(
+ num_experts,
+ ffn_dim,
+ hidden_dim,
+ weight_type='down',
+ block_size=self.block_size,
+ dtype=fp8_dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep,
+ )
+
+ self.hidden_dim = hidden_dim
+ self.ffn_dim = ffn_dim
+ self.num_experts = num_experts
+ self.dtype = dtype
+ self.device = device
+ world_size, _ = get_world_rank()
+ if world_size == 1:
+ all_reduce = False
+ self.all_reduce = all_reduce
+
+ def update_weights(self):
+ """update weights."""
+ (gate_up_weights, down_weights, gate_up_scale,
+ down_scale) = self.impl.update_weights(self.gate_up.weight,
+ self.down.weight,
+ self.gate_up.scale,
+ self.down.scale)
+ self.gate_up.update_weight(gate_up_weights, gate_up_scale)
+ self.down.update_weight(down_weights, down_scale)
def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.LongTensor):
ret = self.impl.forward(hidden_states, topk_weights, topk_ids,
- self.gate_up_weights, self.down_weights,
+ self.gate_up.weight, self.gate_up.scale,
+ self.down.weight, self.down.scale,
self.expert_list)
if self.all_reduce:
dist.all_reduce(ret)
return ret
+
+
+def build_fused_moe(
+ hidden_dim: int,
+ ffn_dim: int,
+ num_experts: int,
+ top_k: int,
+ renormalize: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ all_reduce: bool = True,
+ enable_ep: bool = False,
+ quant_config: Any = None,
+):
+ """fused moe builder."""
+
+ if quant_config is None:
+ return FusedMoE(
+ hidden_dim=hidden_dim,
+ ffn_dim=ffn_dim,
+ num_experts=num_experts,
+ top_k=top_k,
+ renormalize=renormalize,
+ dtype=dtype,
+ device=device,
+ all_reduce=all_reduce,
+ enable_ep=enable_ep,
+ )
+
+ quant_method = quant_config['quant_method']
+ if quant_method == 'smooth_quant':
+ return FusedMoEW8A8(
+ hidden_dim=hidden_dim,
+ ffn_dim=ffn_dim,
+ num_experts=num_experts,
+ top_k=top_k,
+ renormalize=renormalize,
+ dtype=dtype,
+ device=device,
+ all_reduce=all_reduce,
+ enable_ep=enable_ep,
+ )
+ elif quant_method == 'fp8':
+ fmt = quant_config.get('fmt', 'e4m3')
+ if fmt == 'e4m3':
+ fp8_dtype = torch.float8_e4m3fn
+ elif fmt == 'e5m2':
+ fp8_dtype = torch.float8_e5m2
+ else:
+ raise TypeError(f'Unsupported fp8 fmt: {fmt}')
+ return FusedMoEBlockedF8(
+ hidden_dim=hidden_dim,
+ ffn_dim=ffn_dim,
+ num_experts=num_experts,
+ top_k=top_k,
+ renormalize=renormalize,
+ fp8_dtype=fp8_dtype,
+ dtype=dtype,
+ device=device,
+ all_reduce=all_reduce,
+ enable_ep=enable_ep,
+ )
+ else:
+ raise RuntimeError(f'Unsupported quant method: {quant_method}')
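`build_fused_moe` mirrors the linear builders: no `quant_config` yields the plain `FusedMoE`, `smooth_quant` yields `FusedMoEW8A8`, and `fp8` (with `fmt` choosing e4m3 or e5m2) yields `FusedMoEBlockedF8`. A usage sketch; the sizes are illustrative and it needs an initialized lmdeploy backend to actually build:

```python
import torch
from lmdeploy.pytorch.nn.moe import build_fused_moe

moe = build_fused_moe(
    hidden_dim=4096,
    ffn_dim=11008,            # full ffn dim; TP splitting happens inside _update_args
    num_experts=8,
    top_k=2,
    renormalize=True,
    dtype=torch.float16,
    quant_config={'quant_method': 'fp8', 'fmt': 'e4m3'},   # -> FusedMoEBlockedF8
)
# routing comes from SoftmaxTopK, then:
# out = moe(hidden_states, topk_weights, topk_ids)
```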
diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py
index ef244ff73f..7e2c820399 100644
--- a/lmdeploy/pytorch/nn/norm.py
+++ b/lmdeploy/pytorch/nn/norm.py
@@ -4,19 +4,23 @@
import torch
from torch import nn
+from lmdeploy.pytorch.distributed import get_world_rank
+
from ..backends import OpType, get_backend
+from .utils import chunk_aligned, get_distribute_size
def _is_w8a8(quant_config: Any):
"""is w8a8."""
- if quant_config is None:
- return False
- else:
+ quant_dtype = None
+ w8a8_flag = False
+ if quant_config is not None:
quant_method = quant_config['quant_method']
- if quant_method == 'w8a8':
- return True
- else:
- return False
+ if quant_method == 'smooth_quant':
+ w8a8_flag = True
+ quant_dtype = quant_config.get('quant_dtype', 'int8')
+ quant_dtype = eval(f'torch.{quant_dtype}')
+ return w8a8_flag, quant_dtype
class RMSNorm(nn.Module):
@@ -27,16 +31,44 @@ def __init__(self,
eps: float = 1e-6,
dtype: torch.dtype = None,
device: torch.device = None,
- quant_config: Any = None):
+ quant_config: Any = None,
+ tp: bool = False,
+ align: int = 1):
super().__init__()
backend = get_backend()
- if _is_w8a8(quant_config):
+
+ w8a8_flag, quant_dtype = _is_w8a8(quant_config)
+ if w8a8_flag:
builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8)
else:
builder = backend.get_layer_impl_builder(OpType.RMSNorm)
+
+ if tp:
+ world_size, rank = get_world_rank()
+ hidden_size = get_distribute_size(hidden_size,
+ world_size,
+ rank,
+ align=align)
+
self.register_parameter('weight',
self.create_weight(hidden_size, dtype, device))
- self.impl = builder.build(hidden_size, eps)
+ if w8a8_flag:
+ self.impl = builder.build(hidden_size,
+ eps,
+ quant_dtype=quant_dtype)
+ else:
+ self.impl = builder.build(hidden_size, eps)
+
+ if tp:
+ self.weight.weight_loader = self.weight_loader
+ self.align = align
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ loaded_weight = chunk_aligned(loaded_weight, world_size, 0,
+ self.align)[rank]
+ param.copy_(loaded_weight)
@staticmethod
def create_weight(hidden_size: int,
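Two behavioural changes land here: `_is_w8a8` now returns a `(flag, quant_dtype)` tuple instead of a bool, so call sites must unpack it, and `tp=True` shards the norm weight itself via the alignment helpers in `nn/utils.py`. A tiny illustration of the new return values (the config dicts are hypothetical):

```python
import torch
from lmdeploy.pytorch.nn.norm import _is_w8a8

# smooth_quant without an explicit quant_dtype falls back to int8
assert _is_w8a8({'quant_method': 'smooth_quant'}) == (True, torch.int8)
# fp8-flavoured smooth_quant
assert _is_w8a8({'quant_method': 'smooth_quant',
                 'quant_dtype': 'float8_e4m3fn'}) == (True, torch.float8_e4m3fn)
# no quantization at all
assert _is_w8a8(None) == (False, None)
```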
diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py
index 3289f858a7..085b12c3e9 100644
--- a/lmdeploy/pytorch/nn/utils.py
+++ b/lmdeploy/pytorch/nn/utils.py
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
def div_up(a: int, b: int):
"""div up."""
return (a + b - 1) // b
@@ -11,7 +14,26 @@ def get_distribute_size(feature_size: int,
"""update feature size."""
assert feature_size % align == 0
aligned_size = feature_size // align
- align_per_rank = div_up(aligned_size, world_size)
- prev_feats = align_per_rank * rank
- updated_aligned_size = min(align_per_rank, aligned_size - prev_feats)
+ # try to give every rank the same number of feats
+ updated_aligned_size = aligned_size // world_size
+ # if some feats remain, give one extra
+ # feat to each of the lowest ranks
+ if rank < aligned_size % world_size:
+ updated_aligned_size += 1
return updated_aligned_size * align
+
+
+def chunk_aligned(weight: torch.Tensor, chunks: int, dim: int, align: int):
+ """chunk aligned."""
+ if align == 1:
+ return weight.chunk(chunks, dim=dim)
+ size = weight.size(dim)
+ assert size % align == 0
+ aligned_size = size // align
+
+ # split the chunks as evenly as possible
+ align_per_chunk = aligned_size // chunks
+ remain = aligned_size % chunks
+ sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
+ sections = [sec * align for sec in sections]
+ return weight.split(sections, dim=dim)
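The rewritten `get_distribute_size` spreads the remainder over the lowest ranks instead of handing `div_up`-sized shares to the first ranks, and `chunk_aligned` slices a loaded tensor into exactly the same sections. A worked example with deliberately uneven numbers:

```python
import torch
from lmdeploy.pytorch.nn.utils import chunk_aligned, get_distribute_size

# 10 alignment groups over 4 ranks: the new policy gives [3, 3, 2, 2] groups per rank
# (the old div_up policy produced [3, 3, 3, 1]).
sizes = [get_distribute_size(10 * 64, 4, rank, align=64) for rank in range(4)]
assert sizes == [192, 192, 128, 128]

# chunk_aligned splits a real tensor into the matching sections
w = torch.randn(10 * 64, 8)
chunks = chunk_aligned(w, 4, 0, 64)
assert [c.size(0) for c in chunks] == sizes
```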
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index e28e375965..0d901d75a3 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -274,11 +274,19 @@ def has_unfinished(self):
return self.has_running() or self.has_waiting()
def has_running(self):
- return self.seq_manager.num_sequences(MessageStatus.RUNNING) > 0
+ return self.num_running() > 0
def has_waiting(self):
- return self.seq_manager.num_sequences(MessageStatus.WAITING) > 0
+ return self.num_waiting() > 0
def get_block_tables(self, seqs: SeqList):
"""get block table of the sequences."""
return [self.block_manager.get_block_table(seq) for seq in seqs]
+
+ def num_running(self):
+ """num running."""
+ return self.seq_manager.num_sequences(MessageStatus.RUNNING)
+
+ def num_waiting(self):
+ """num waiting."""
+ return self.seq_manager.num_sequences(MessageStatus.WAITING)
diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py
index 7fa568651b..67452f78e3 100644
--- a/lmdeploy/pytorch/supported_models.py
+++ b/lmdeploy/pytorch/supported_models.py
@@ -47,9 +47,9 @@
# cogvlm-chat
CogVLMForCausalLM=True,
# llava
- LlavaLlamaForCausalLM=True,
+ LlavaLlamaForCausalLM=False,
# llava mistral
- LlavaMistralForCausalLM=True,
+ LlavaMistralForCausalLM=False,
# deepseekvl
MultiModalityCausalLM=False,
# StarCoder2
diff --git a/lmdeploy/pytorch/tools/make_inputs.py b/lmdeploy/pytorch/tools/make_inputs.py
index f2d23830b7..053e7d0918 100644
--- a/lmdeploy/pytorch/tools/make_inputs.py
+++ b/lmdeploy/pytorch/tools/make_inputs.py
@@ -135,6 +135,7 @@ def __fill_kv_caches(kv_caches, past_key_values, block_offsets):
return StepContext.new(
inputs=model_inputs,
+ model_config=model_config,
world_size=world_size,
kv_caches=kv_caches,
)
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 69bead8906..86c0936de2 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -231,9 +231,10 @@ def __call__(self,
"""Inference a batch of prompts.
Args:
- prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
- prompts. It accepts: string prompt, a list of string prompts,
- a chat history in OpenAI format or a list of chat history.
+ prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
+ batch of prompts. It accepts: string prompt, a list of string
+ prompts, a chat history in OpenAI format or a list of chat
+ history.
gen_config (GenerationConfig | None): a instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
@@ -305,9 +306,10 @@ def batch_infer(self,
"""Inference a batch of prompts.
Args:
- prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
- prompts. It accepts: string prompt, a list of string prompts,
- a chat history in OpenAI format or a list of chat history.
+ prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
+ batch of prompts. It accepts: string prompt, a list of string
+ prompts, a chat history in OpenAI format or a list of chat
+ history.
gen_config (GenerationConfig | None): a instance of or a list of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
@@ -382,9 +384,10 @@ def stream_infer(
"""Inference a batch of prompts with stream mode.
Args:
- prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
- prompts. It accepts: string prompt, a list of string prompts,
- a chat history in OpenAI format or a list of chat history.
+ prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
+ batch of prompts. It accepts: string prompt, a list of string
+ prompts, a chat history in OpenAI format or a list of chat
+ history.
gen_config (GenerationConfig | None): a instance of or a list of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
@@ -510,8 +513,8 @@ async def generate(
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
if not gen_config.do_sample:
- logger.warn(f'GenerationConfig: {gen_config}')
- logger.warn(
+ logger.warning(f'GenerationConfig: {gen_config}')
+ logger.warning(
'Since v0.6.0, lmdeploy add `do_sample` in '
'GenerationConfig. It defaults to False, meaning greedy '
'decoding. Please set `do_sample=True` if sampling '
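The corrected annotation (`List[List[Dict]]` for a batch of chat histories) reflects the four input shapes these entry points accept. A hedged illustration; the model path is a placeholder:

```python
from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')        # placeholder model

pipe('hi')                                             # str
pipe(['hi', 'how are you?'])                           # List[str]
pipe([{'role': 'user', 'content': 'hi'}])              # List[Dict]: one chat history
pipe([                                                 # List[List[Dict]]: batch of histories
    [{'role': 'user', 'content': 'hi'}],
    [{'role': 'user', 'content': 'how are you?'}],
])
```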
diff --git a/lmdeploy/serve/gradio/vl.py b/lmdeploy/serve/gradio/vl.py
index 103bcc5889..bf8ee87e68 100644
--- a/lmdeploy/serve/gradio/vl.py
+++ b/lmdeploy/serve/gradio/vl.py
@@ -70,8 +70,6 @@ def run_local(model_path: str,
**kwargs):
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
- if isinstance(backend_config, PytorchEngineConfig):
- backend_config.thread_safe = True
vision_config = VisionConfig(thread_safe=True)
engine = VLAsyncEngine(model_path=model_path,
model_name=model_name,
@@ -115,10 +113,13 @@ def chat(chatbot, session, max_new_tokens, top_p, top_k, temperature):
else:
prompt = history[-1][0][0]
images = history[-1][0][1:]
- prompt = (prompt, images)
-
- logger.info('prompt: ' + str(prompt))
- prompt = engine.vl_prompt_template.prompt_to_messages(prompt)
+ # convert prompt into GPT4V format
+ prompt = [
+ dict(role='user', content=[dict(type='text', text=prompt)])
+ ]
+ for image in images:
+ prompt[0]['content'].append(
+ dict(type='image_data', image_data=dict(data=image)))
t0 = time.perf_counter()
inputs = _run_until_complete(
engine._get_prompt_input(prompt, True, sequence_start, ''))
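The gradio VL demo no longer relies on `vl_prompt_template.prompt_to_messages`; it assembles the GPT-4V style message list directly. For one text prompt plus two images the constructed `prompt` looks like this (the image variables stand in for the PIL images taken from the chat history):

```python
image0 = image1 = None   # stand-ins for the PIL.Image objects from the chatbot history

prompt = [
    dict(role='user',
         content=[
             dict(type='text', text='describe these images'),
             dict(type='image_data', image_data=dict(data=image0)),
             dict(type='image_data', image_data=dict(data=image1)),
         ])
]
```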
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index f515e49d2e..eb424a0829 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -509,7 +509,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
for call_info in call_info_list
]
except Exception as e:
- logger.error(f'Exception: {e}')
+ logger.error(f'Failed to parse {text}. Exception: {e}.')
return create_error_response(
HTTPStatus.BAD_REQUEST,
'Failed to parse fc related info to json format!')
@@ -946,6 +946,20 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return JSONResponse(ret)
+def handle_torchrun():
+ """Disable mmengine logging's device-id lookup when running under torchrun."""
+
+ def dummy_get_device_id():
+ return 0
+
+ if int(os.environ.get('LOCAL_RANK', -1)) > 0:
+ from lmdeploy.vl.model.utils import _set_func
+
+ # the replacement can't be recovered
+ _set_func('mmengine.logging.logger._get_device_id',
+ dummy_get_device_id)
+
+
@router.on_event('startup')
async def startup_event():
if VariableInterface.proxy_url is None:
@@ -1071,8 +1085,8 @@ def serve(model_path: str,
ssl_certfile = os.environ['SSL_CERTFILE']
http_or_https = 'https'
+ handle_torchrun()
_, pipeline_class = get_task(model_path)
-
VariableInterface.async_engine = pipeline_class(
model_path=model_path,
speculative_model=speculative_model,
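`handle_torchrun` only patches mmengine on non-zero local ranks: torchrun exports `LOCAL_RANK`, so rank 0 (and plain single-process serving, where the variable is absent and the `-1` default applies) keeps the original `_get_device_id`. A tiny sketch of just the gating condition:

```python
import os

def is_nonzero_torchrun_rank() -> bool:
    """Mirrors the check in handle_torchrun(): missing LOCAL_RANK -> -1 -> False."""
    return int(os.environ.get('LOCAL_RANK', -1)) > 0

# torchrun sets LOCAL_RANK=0..N-1 on its workers; only ranks 1..N-1 get the dummy patch.
```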
diff --git a/lmdeploy/serve/proxy/constants.py b/lmdeploy/serve/proxy/constants.py
index 88d86a3e33..5bf6e67659 100644
--- a/lmdeploy/serve/proxy/constants.py
+++ b/lmdeploy/serve/proxy/constants.py
@@ -2,8 +2,8 @@
import enum
-LATENCY_DEEQUE_LEN = 15
-API_TIMEOUT_LEN = 100
+LATENCY_DEQUE_LEN = 15
+API_READ_TIMEOUT = 100
class Strategy(enum.Enum):
diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py
index 5f05930bd0..392ede3267 100644
--- a/lmdeploy/serve/proxy/proxy.py
+++ b/lmdeploy/serve/proxy/proxy.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
import copy
import json
import os
@@ -18,14 +19,15 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
+from requests.exceptions import RequestException
from lmdeploy.serve.openai.api_server import (check_api_key,
create_error_response)
from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, CompletionRequest, ModelCard, ModelList,
ModelPermission)
-from lmdeploy.serve.proxy.constants import (API_TIMEOUT_LEN,
- LATENCY_DEEQUE_LEN, ErrorCodes,
+from lmdeploy.serve.proxy.constants import (API_READ_TIMEOUT,
+ LATENCY_DEQUE_LEN, ErrorCodes,
Strategy, err_msg)
from lmdeploy.utils import get_logger
@@ -36,7 +38,7 @@ class Status(BaseModel):
"""Status protocol consists of models' information."""
models: Optional[List[str]] = Field(default=[], examples=[[]])
unfinished: int = 0
- latency: Deque = Field(default=deque(maxlen=LATENCY_DEEQUE_LEN),
+ latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN),
examples=[[]])
speed: Optional[int] = Field(default=None, examples=[None])
@@ -87,6 +89,9 @@ def __init__(self,
with open(self.config_path, 'r') as config_file:
self.nodes = yaml.safe_load(config_file)['nodes']
for url, status in self.nodes.items():
+ latency = deque(status.get('latency', []),
+ maxlen=LATENCY_DEQUE_LEN)
+ status['latency'] = latency
status = Status(**status)
self.nodes[url] = status
self.heart_beat_thread = threading.Thread(target=heart_beat_controller,
@@ -99,7 +104,7 @@ def update_config_file(self):
nodes = copy.deepcopy(self.nodes)
for url, status in nodes.items():
nodes[url] = status.model_dump()
- nodes[url]['latency'] = list(status.latency)
+ nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:]
with open(self.config_path, 'w') as config_file: # update cfg yml
yaml.dump(dict(nodes=nodes), config_file)
@@ -149,7 +154,8 @@ def remove_stale_nodes_by_expiration(self):
to_be_deleted.append(node_url)
for node_url in to_be_deleted:
self.remove(node_url)
- logger.info(f'Removed node_url: {node_url}')
+ logger.info(f'Removed node_url: {node_url} '
+ 'due to heart beat expiration')
@property
def model_list(self):
@@ -251,7 +257,7 @@ def handle_unavailable_model(self, model_name):
Args:
model_name (str): the model in the request.
"""
- logger.info(f'no model name: {model_name}')
+ logger.warning(f'no model name: {model_name}')
ret = {
'error_code': ErrorCodes.MODEL_NOT_FOUND,
'text': err_msg[ErrorCodes.MODEL_NOT_FOUND],
@@ -260,51 +266,54 @@ def handle_unavailable_model(self, model_name):
def handle_api_timeout(self, node_url):
"""Handle the api time out."""
- logger.info(f'api timeout: {node_url}')
+ logger.warning(f'api timeout: {node_url}')
ret = {
- 'error_code': ErrorCodes.API_TIMEOUT,
+ 'error_code': ErrorCodes.API_TIMEOUT.value,
'text': err_msg[ErrorCodes.API_TIMEOUT],
}
return json.dumps(ret).encode() + b'\n'
- def stream_generate(self, request: Dict, node_url: str, node_path: str):
+ def stream_generate(self, request: Dict, node_url: str, endpoint: str):
"""Return a generator to handle the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
- node_path (str): the node path. Such as `/v1/chat/completions`.
+ endpoint (str): the endpoint. Such as `/v1/chat/completions`.
"""
try:
response = requests.post(
- node_url + node_path,
+ node_url + endpoint,
json=request,
- stream=request['stream'],
- timeout=API_TIMEOUT_LEN,
+ stream=True,
+ timeout=(5, API_READ_TIMEOUT),
)
for chunk in response.iter_lines(decode_unicode=False,
delimiter=b'\n'):
if chunk:
yield chunk + b'\n\n'
- except requests.exceptions.RequestException as e: # noqa
+ except (Exception, GeneratorExit, RequestException) as e: # noqa
+ logger.error(f'caught an exception: {e}')
+ # exception happened, reduce unfinished num
yield self.handle_api_timeout(node_url)
- async def generate(self, request: Dict, node_url: str, node_path: str):
+ async def generate(self, request: Dict, node_url: str, endpoint: str):
"""Return a the response of the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
- node_path (str): the node path. Such as `/v1/chat/completions`.
+ endpoint (str): the endpoint. Such as `/v1/chat/completions`.
"""
try:
import httpx
async with httpx.AsyncClient() as client:
- response = await client.post(node_url + node_path,
+ response = await client.post(node_url + endpoint,
json=request,
- timeout=API_TIMEOUT_LEN)
+ timeout=API_READ_TIMEOUT)
return response.text
- except requests.exceptions.RequestException as e: # noqa
+ except (Exception, GeneratorExit, RequestException, asyncio.CancelledError) as e: # noqa # yapf: disable
+ logger.error(f'caught an exception: {e}')
return self.handle_api_timeout(node_url)
def pre_call(self, node_url):
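
With the change above, `stream_generate` always streams and uses a `(connect, read)` timeout pair, while the async path keeps a single read timeout. A small sketch of the `requests` timeout-tuple semantics; the node URL and payload are placeholders:

```python
import requests
from requests.exceptions import ConnectTimeout, ReadTimeout

API_READ_TIMEOUT = 100  # mirrors lmdeploy.serve.proxy.constants

node_url = 'http://localhost:23333'          # hypothetical backend node
payload = {'model': 'dummy-model', 'prompt': 'hello', 'stream': True}
try:
    # The connection must succeed within 5 s; each read may take up to 100 s,
    # which suits long streaming generations better than one overall timeout.
    resp = requests.post(node_url + '/v1/completions',
                         json=payload,
                         stream=True,
                         timeout=(5, API_READ_TIMEOUT))
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b'\n'):
        if chunk:
            print(chunk)
except (ConnectTimeout, ReadTimeout) as e:
    print(f'request timed out: {e}')
```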
@@ -381,7 +390,11 @@ def add_node(node: Node, raw_request: Request = None):
RPM or other metric. All the values of nodes should be the same metric.
"""
try:
- node_manager.add(node.url, node.status)
+ res = node_manager.add(node.url, node.status)
+ if res is not None:
+ logger.error(f'add node {node.url} failed, {res}')
+ return res
+ logger.info(f'add node {node.url} successfully')
return 'Added successfully'
except: # noqa
return 'Failed to add, please check the input url.'
@@ -392,8 +405,10 @@ def remove_node(node_url: str):
"""Show available models."""
try:
node_manager.remove(node_url)
+ logger.info(f'delete node {node_url} successfully')
return 'Deleted successfully'
except: # noqa
+ logger.error(f'delete node {node_url} failed.')
return 'Failed to delete, please check the input url.'
@@ -407,28 +422,50 @@ async def chat_completions_v1(request: ChatCompletionRequest,
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- - messages: string prompt or chat history in OpenAI format. A example
- for chat history is `[{"role": "user", "content":"knock knock"}]`.
+ - messages: string prompt or chat history in OpenAI format. Chat history
+ example: `[{"role": "user", "content": "hi"}]`.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
- message. Only support one here.
+ message. **Only support one here**.
- stream: whether to stream the results or not. Default to false.
- - max_tokens (int): output token nums
+ - max_tokens (int | None): output token nums. Default to None.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that are encoded to one token index.
+ - response_format (Dict | None): Only the pytorch backend supports formatting the
+ response. Examples: `{"type": "json_schema", "json_schema": {"name":
+ "test","schema": {"properties": {"name": {"type": "string"}},
+ "required": ["name"], "type": "object"}}}`
+ or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}`
+ - logit_bias (Dict): Bias to logits. Only supported in pytorch engine.
+ - tools (List): A list of tools the model may call. Currently, only
+ internlm2 functions are supported as a tool. Use this to specify a
+ list of functions for which the model can generate JSON inputs.
+ - tool_choice (str | object): Controls which (if any) tool is called by
+ the model. `none` means the model will not call any tool and instead
+ generates a message. Specifying a particular tool via {"type":
+ "function", "function": {"name": "my_function"}} forces the model to
+ call that tool. `auto` or `required` will pass all the tool information
+ to the model.
Additional arguments supported by LMDeploy:
+ - top_k (int): The number of the highest probability vocabulary
+ tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- - session_id (int): if not specified, will set random value
+ - skip_special_tokens (bool): Whether or not to remove special tokens
+ in the decoding. Defaults to True.
+ - min_new_tokens (int): To generate at least this number of tokens.
+ - min_p (float): Minimum token probability, which will be scaled by the
+ probability of the most likely token. It must be a value between
+ 0 and 1. Typical values are in the 0.01-0.2 range, comparably
+ selective as setting `top_p` in the 0.99-0.8 range (use the
+ opposite of normal `top_p` values)
Currently we do not support the following features:
- - function_call (Users should implement this by themselves)
- - logit_bias (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
@@ -439,6 +476,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
if not node_url:
return node_manager.handle_unavailable_model(request.model)
+ logger.info(f'A request is dispatched to {node_url}')
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
@@ -465,13 +503,13 @@ async def completions_v1(request: CompletionRequest,
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- - max_tokens (int): output token nums
+ - max_tokens (int): output token nums. Default to 16.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
- message. Only support one here.
+ message. **Only support one here**.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
@@ -481,7 +519,8 @@ async def completions_v1(request: CompletionRequest,
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- - session_id (int): if not specified, will set random value
+ - skip_special_tokens (bool): Whether or not to remove special tokens
+ in the decoding. Defaults to True.
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
@@ -497,6 +536,7 @@ async def completions_v1(request: CompletionRequest,
if not node_url:
return node_manager.handle_unavailable_model(request.model)
+ logger.info(f'A request is dispatched to {node_url}')
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
@@ -517,6 +557,7 @@ def proxy(server_name: str = '0.0.0.0',
'min_observed_latency'] = 'min_expected_latency',
api_keys: Optional[Union[List[str], str]] = None,
ssl: bool = False,
+ log_level: str = 'INFO',
**kwargs):
"""To launch the proxy server.
@@ -540,6 +581,7 @@ def proxy(server_name: str = '0.0.0.0',
if ssl:
ssl_keyfile = os.environ['SSL_KEYFILE']
ssl_certfile = os.environ['SSL_CERTFILE']
+ logger.setLevel(log_level)
uvicorn.run(app=app,
host=server_name,
port=server_port,
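
With `log_level` now exposed, the proxy can be launched programmatically with a chosen verbosity. A minimal sketch using only the arguments visible in this diff (other routing options keep their defaults):

```python
from lmdeploy.serve.proxy.proxy import proxy

if __name__ == '__main__':
    # Serve the proxy on port 8000 and raise the logger to DEBUG so that the
    # new dispatch / add-node / delete-node messages become visible.
    proxy(server_name='0.0.0.0', server_port=8000, log_level='DEBUG')
```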
diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py
index c293cd71c8..becf1b76fb 100644
--- a/lmdeploy/serve/vl_async_engine.py
+++ b/lmdeploy/serve/vl_async_engine.py
@@ -1,148 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Union
+import asyncio
+from typing import Dict, List, Literal, Optional, Tuple, Union
-import numpy as np
+import PIL
+from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig,
+ VisionConfig)
from lmdeploy.pytorch.check_env import try_import_deeplink
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.utils import get_logger
-from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX, IMAGE_TOKEN
from lmdeploy.vl.engine import ImageEncoder
-from lmdeploy.vl.templates import VLPromptType, get_vl_prompt_template
+from lmdeploy.vl.utils import load_image
logger = get_logger('lmdeploy')
+VLPromptType = Union[str, Tuple[str, PIL.Image.Image],
+ Tuple[str, List[PIL.Image.Image]]]
+
class VLAsyncEngine(AsyncEngine):
"""Visual Language Async inference engine."""
- def __init__(self, model_path: str, **kwargs) -> None:
- vision_config = kwargs.pop('vision_config', None)
- backend_config = kwargs.get('backend_config', None)
- if kwargs.get('backend', '') == 'pytorch':
+ def __init__(self,
+ model_path: str,
+ backend: Literal['turbomind', 'pytorch'] = 'turbomind',
+ backend_config: Optional[Union[TurbomindEngineConfig,
+ PytorchEngineConfig]] = None,
+ vision_config: Optional[VisionConfig] = None,
+ **kwargs) -> None:
+ if backend == 'pytorch':
try_import_deeplink(backend_config.device_type)
self.vl_encoder = ImageEncoder(model_path,
+ backend,
vision_config,
backend_config=backend_config)
- super().__init__(model_path, **kwargs)
+ super().__init__(model_path,
+ backend=backend,
+ backend_config=backend_config,
+ **kwargs)
if self.model_name == 'base':
raise RuntimeError(
'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501
)
- self.vl_prompt_template = get_vl_prompt_template(
- model_path, self.chat_template, self.model_name)
- def _convert_prompts(self,
+ @classmethod
+ def _convert_prompts(cls,
prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]]):
- """convert prompts to openai format."""
+ """convert prompts to openai GPT4V format."""
if isinstance(prompts, str) or isinstance(prompts, tuple):
- _prompts = self.vl_prompt_template.prompt_to_messages(prompts)
+ _prompts = cls.prompt_to_messages(prompts)
elif isinstance(prompts[0], tuple) or isinstance(prompts[0], str):
- _prompts = [
- self.vl_prompt_template.prompt_to_messages(x) for x in prompts
- ]
+ _prompts = [cls.prompt_to_messages(x) for x in prompts]
else:
_prompts = prompts
return _prompts
async def _get_prompt_input(self,
- prompt: Dict,
+ messages: Union[str, List[Dict]],
do_preprocess: bool,
sequence_start: bool,
adapter_name: str,
tools: Optional[List[object]] = None,
**kwargs):
- """get input_ids, embeddings and offsets."""
- if do_preprocess:
- decorated = self.vl_prompt_template.messages2prompt(
- prompt, sequence_start)
- else:
- decorated = prompt
- segs = decorated.split(IMAGE_TOKEN)
-
- results = {}
- input_ids = []
- from lmdeploy.vl.templates import (MllamaTempateWrapper,
- MolmoChatTemplateWrapper,
- Qwen2VLChatTemplateWrapper)
- ranges = None
- grid_thws = None
- if len(segs) > 1:
- # yapf: disable
- images_with_kwargs = await self.vl_prompt_template.async_collect_pil_images(prompt) # noqa: E501
- # yapf: enable
- features = []
- if len(images_with_kwargs) > 0:
- images, image_kwargs = list(zip(*images_with_kwargs))
- features = await self.vl_encoder.async_infer(
- images, image_kwargs)
-
- from lmdeploy.vl.templates import MiniCPMVTempateWrapper
- if isinstance(self.vl_prompt_template, MiniCPMVTempateWrapper):
- decorated, features = self.vl_prompt_template.update_image_token( # noqa: E501
- decorated, features)
- segs = decorated.split(IMAGE_TOKEN)
-
- if isinstance(self.vl_prompt_template,
- Qwen2VLChatTemplateWrapper):
- grid_thws = [x['grid_thw'] for x in features]
- features = [x['embeddings'] for x in features]
-
- if isinstance(self.vl_prompt_template, MllamaTempateWrapper):
- # llama3.2 just encode <|image|> and inference
- decorated = decorated.replace(IMAGE_TOKEN, '<|image|>')
- input_ids = self.tokenizer.encode(decorated,
- add_bos=sequence_start)
- results['input_ids'] = input_ids
- results['prompt'] = decorated
- assert len(features)
- results['cross_attention_states'] = features[0]
- return results
-
- if isinstance(self.vl_prompt_template,
- MolmoChatTemplateWrapper):
- return features[0]
-
- features = [x.cpu().numpy() for x in features]
- input_ids = []
- begins = []
- ends = []
- if len(segs) != len(features) + 1:
- logger.error(
- f'the number of {IMAGE_TOKEN} is not equal '
- f'to input images, {len(segs) - 1} vs {len(features)}')
- features = features[:len(segs) - 1]
- for i, seg in enumerate(segs):
- if i > 0 and i <= len(features):
- image_dim = features[i - 1].shape[0]
- begins.append(len(input_ids))
- ends.append(begins[-1] + image_dim)
- input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
- seg_ids = self.tokenizer.encode(seg,
- add_bos=((i == 0)
- and sequence_start))
- input_ids.extend(seg_ids)
- ranges = np.stack([begins, ends], axis=1).tolist()
- results['input_embeddings'] = features or None
- results['input_embedding_ranges'] = ranges or None
+ """process messages and return the required data for the inference
+ engines.
+
+ Refer to pytorch.engine.EngineInstance.async_stream_infer and
+ turbomind.TurboMindInstance.async_stream_infer for the argument
+ specification.
+ """
+ if isinstance(messages, str):
+ return await super()._get_prompt_input(messages, do_preprocess,
+ sequence_start,
+ adapter_name, tools,
+ **kwargs)
+ elif isinstance(messages, List):
+ has_multimodal_input = any(
+ isinstance(message['content'], list) and any(
+ item['type'] in ['image_url', 'image_data']
+ for item in message['content']) for message in messages)
+ if not has_multimodal_input:
+ return await super()._get_prompt_input(messages, do_preprocess,
+ sequence_start,
+ adapter_name, tools,
+ **kwargs)
else:
- input_ids = self.tokenizer.encode(decorated,
- add_bos=sequence_start)
-
- if isinstance(self.vl_prompt_template, Qwen2VLChatTemplateWrapper):
- # TODO: refactor _get_prompt_input function
- mrope_position_ids, mrope_position_delta = \
- self.vl_prompt_template.get_mrope_info(
- len(input_ids), grid_thws=grid_thws,
- embedding_ranges=ranges)
- results['mrope_position_ids'] = mrope_position_ids
- results['mrope_position_delta'] = mrope_position_delta
-
- results['input_ids'] = input_ids
- results['prompt'] = decorated
+ raise RuntimeError(f'unsupported messages {messages}')
+
+ messages = await self.async_convert_to_pil_images(messages)
+ results = await self.vl_encoder.preprocess(messages)
+ if self.backend == 'turbomind':
+ # for tm engine, this module perform vision embedding after image
+ # preprocessing. It utilizes the hf model's vision embeddings
+ # functions and returns the input_ids, input_embeddings,
+ # embedding_ranges and so on. All the returned values are passed
+ # to tm engine for token generation
+ results = await self.vl_encoder.async_infer(results)
+ results = await self.vl_encoder.wrap_for_turbomind(
+ results, self.chat_template, self.tokenizer, sequence_start)
+ elif self.backend == 'pytorch':
+ # for pt engine, this module only conduct the image preprocessing
+ # It leaves the vision embedding to the pt engine
+ results = await self.vl_encoder.wrap_for_pytorch(
+ results, self.chat_template, self.tokenizer, sequence_start)
return results
+ @classmethod
+ async def async_convert_to_pil_images(cls,
+ messages: List[Dict]) -> List[Dict]:
+ """Scan the provided messages to find image URLs or base64-encoded
+ image data. Loads the images into Pillow image objects.
+
+ Args:
+ messages (List[Dict]): a user request of GPT4V message format
+ """
+ if isinstance(messages, Dict):
+ messages = [messages]
+ assert isinstance(messages, List)
+
+ out_messages = [None] * len(messages)
+
+ def _inner_call(i, in_messages, out_messages):
+ role = in_messages[i]['role']
+ content = in_messages[i]['content']
+ assert role in ['system', 'user', 'assistant'], \
+ f'unsupported role "{role}"'
+ if role != 'user' or isinstance(content, str):
+ # the content is a user's prompt or an assistant's prompt,
+ # returning it directly
+ out_messages[i] = in_messages[i]
+ return
+ # the role is a user and the content is a list, in which there
+ # might be image_url or image_data
+ assert isinstance(content, List)
+ message = dict(role=role, content=[])
+ for item in content:
+ # image url or base64-encoded image data
+ if item['type'] == 'image_url':
+ """
+ convert the following item:
+ {
+ 'type': 'image_url',
+ 'image_url': {
+ 'url': 'image url or base64-encoded image data',
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ }
+ to:
+ {
+ 'type': 'image',
+ 'image': Pillow.Image,
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ """ # noqa
+ data = item['image_url'].copy()
+ try:
+ url = data.pop('url')
+ image = load_image(url)
+ data.update(type='image', image=image)
+ message['content'].append(data)
+ except KeyError:
+ logger.error(f'invalid format {message}')
+ elif item['type'] == 'image_data':
+ """
+ convert the following item:
+ {
+ 'type': 'image_data',
+ 'image_data': {
+ 'data': Pillow.Image,
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ }
+ to:
+ {
+ 'type': 'image',
+ 'image': Pillow.Image,
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ """ # noqa
+ data = item['image_data'].copy()
+ try:
+ image = data.pop('data')
+ data.update(type='image', image=image)
+ message['content'].append(data)
+ except KeyError:
+ logger.error(f'invalid format {message}')
+ elif item['type'] == 'text':
+ message['content'].append(item)
+ else:
+ logger.error(f'unexpected content type {message}')
+ out_messages[i] = message
+
+ await asyncio.gather(*[
+ asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
+ messages, out_messages)
+ for i in range(len(messages))
+ ])
+ return out_messages
+
def batch_infer(self, prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]],
**kwargs):
@@ -173,3 +233,46 @@ def chat(self, prompts: VLPromptType, **kwargs):
last_round = sess.history[-1]
sess.history[-1] = (prompts, last_round[-1])
return sess
+
+ @classmethod
+ def prompt_to_messages(cls, prompt: VLPromptType):
+ """convert prompt to GTP4V format."""
+ messages = {
+ 'role': 'user',
+ 'content': [{
+ 'type': 'text',
+ 'text': '',
+ }]
+ }
+ if isinstance(prompt, str):
+ messages['content'][0]['text'] = prompt
+ else:
+ prompt, images = prompt
+ if not isinstance(images, list):
+ images = [images]
+ messages['content'][0]['text'] = prompt
+ for image in images:
+ # 'image_url': means url or local path to image.
+ # 'image_data': means PIL.Image.Image object.
+ if isinstance(image, str):
+ image = load_image(image)
+ item = {
+ 'type': 'image_data',
+ 'image_data': {
+ 'data': image
+ }
+ }
+ elif isinstance(image, PIL.Image.Image):
+ item = {
+ 'type': 'image_data',
+ 'image_data': {
+ 'data': image
+ }
+ }
+ else:
+ raise ValueError(
+ 'image should be a str(url/path) or PIL.Image.Image')
+
+ messages['content'].append(item)
+
+ return [messages]
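
`prompt_to_messages` is now a classmethod, so the GPT4V-style message format it produces can be inspected without building an engine. A small sketch, assuming an lmdeploy install that includes this change; the image is a synthetic Pillow image:

```python
from PIL import Image

from lmdeploy.serve.vl_async_engine import VLAsyncEngine

image = Image.new('RGB', (224, 224))
messages = VLAsyncEngine.prompt_to_messages(('describe the image', image))
# Expected shape of the result, matching the diff above:
# [{'role': 'user',
#   'content': [{'type': 'text', 'text': 'describe the image'},
#               {'type': 'image_data', 'image_data': {'data': <PIL.Image>}}]}]
print(messages[0]['content'][0]['type'], messages[0]['content'][1]['type'])
```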
diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py
index 77f0bc8dc8..176c3191f4 100644
--- a/lmdeploy/turbomind/deploy/converter.py
+++ b/lmdeploy/turbomind/deploy/converter.py
@@ -6,7 +6,7 @@
import fire
import torch
-from lmdeploy.archs import get_model_arch
+from lmdeploy.archs import get_model_arch, search_nested_config
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.model import MODELS, best_match_model
from lmdeploy.utils import get_logger, get_model
@@ -129,16 +129,17 @@ def get_output_model_registered_name_and_config(model_path: str,
] else 'float16'
elif dtype in ['float16', 'bfloat16']:
if weight_type == 'int4':
- logger.warn(f'The model {model_path} is a quantized model, so the '
- f'specified data type {dtype} is ignored')
+ logger.warning(
+ f'The model {model_path} is a quantized model, so the '
+ f'specified data type {dtype} is ignored')
else:
weight_type = dtype
else:
assert 0, f'unsupported specified data type {dtype}'
if weight_type == 'bfloat16' and not is_bf16_supported():
- logger.warn('data type fallback to float16 since '
- 'torch.cuda.is_bf16_supported is False')
+ logger.warning('data type fallback to float16 since '
+ 'torch.cuda.is_bf16_supported is False')
weight_type = 'float16'
config.model_config.model_arch = model_arch
config.model_config.weight_type = weight_type
@@ -174,23 +175,6 @@ def pack_model_repository(workspace_path: str):
dst=osp.join(model_repo_dir, 'postprocessing'))
-def find_quantization_config(nested, target_key):
- if isinstance(nested, dict):
- for key, value in nested.items():
- if key == target_key:
- return value
- if isinstance(value, (dict, list)):
- result = find_quantization_config(value, target_key)
- if result is not None:
- return result
- elif isinstance(nested, list):
- for item in nested:
- result = find_quantization_config(item, target_key)
- if result is not None:
- return result
- return None
-
-
def get_tm_model(model_path,
model_name,
chat_template_name,
@@ -213,8 +197,7 @@ def get_tm_model(model_path,
If it is None, the turbomind model won't be saved
"""
_, cfg = get_model_arch(model_path)
- quant_config = find_quantization_config(cfg.to_dict(),
- 'quantization_config')
+ quant_config = search_nested_config(cfg.to_dict(), 'quantization_config')
if quant_config:
quant_method = quant_config.get('quant_method')
_group_size = int(quant_config.get('group_size', 0))
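
`find_quantization_config` is removed in favor of the shared `search_nested_config` helper from `lmdeploy.archs`. A standalone sketch of the behavior it is expected to provide, mirroring the deleted function above (the real implementation may differ in detail):

```python
def search_nested_config(nested, target_key):
    """Recursively look up target_key in nested dicts/lists (illustrative)."""
    if isinstance(nested, dict):
        for key, value in nested.items():
            if key == target_key:
                return value
            if isinstance(value, (dict, list)):
                result = search_nested_config(value, target_key)
                if result is not None:
                    return result
    elif isinstance(nested, list):
        for item in nested:
            result = search_nested_config(item, target_key)
            if result is not None:
                return result
    return None


cfg = {'text_config': {'quantization_config': {'quant_method': 'awq',
                                               'group_size': 128}}}
assert search_nested_config(cfg, 'quantization_config')['quant_method'] == 'awq'
assert search_nested_config(cfg, 'quant_method') == 'awq'
```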
diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py
index 52497175ef..1754161ff5 100644
--- a/lmdeploy/turbomind/deploy/module.py
+++ b/lmdeploy/turbomind/deploy/module.py
@@ -191,7 +191,7 @@ def __init__(self, model: BaseOutputModel):
self.attn_bias = model.model_config.attn_bias
def _reorder_and_merge(self, qkvo):
- q, k, v, o = map(transpose, qkvo)
+ q, k, v, o = qkvo
# reorder output dim for tm's rotary embedding layout
if self.model.permute_qk:
q = permute_v2(q, self.head_dim)
@@ -202,6 +202,27 @@ def _reorder_and_merge(self, qkvo):
o = torch.zeros_like(q)
return qkv, o
+ def _repeat_kv(self, qkvo, kind: str):
+ """replicate kv."""
+ q, k, v, o = qkvo
+ head_dim = self.model.model_config.size_per_head
+ hidden_dim = self.model.model_config.hidden_units
+
+ def _repeat(x):
+ dim = hidden_dim if kind != 'bias' else 1
+ x = x.reshape(dim, -1, head_dim)
+ x = x.repeat(1, 1, self.model.repeat_kv)
+ x = x.reshape(dim, -1)
+ return x
+
+ k, v = map(_repeat, (k, v))
+ if kind == 'bias':
+ if o is None:
+ o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
+ q, k, v, o = map(torch.squeeze, (q, k, v, o))
+
+ return (q, k, v, o)
+
def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
if all(x is None for x in qkvo):
return
@@ -209,6 +230,9 @@ def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
if is_lora_a:
qkv, o = map(transpose, qkvo)
else:
+ qkvo = tuple(map(transpose, qkvo))
+ if self.model.repeat_kv:
+ qkvo = self._repeat_kv(qkvo, kind)
qkv, o = self._reorder_and_merge(qkvo)
self.model.save_split(pack_fn(qkv),
self._attn.format(idx, 'w_qkv', kind),
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index f2c981bb24..7ea1a84f35 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -78,6 +78,17 @@ def __init__(self,
self.model_config.expert_inter_size = _pad_inter_size(
self.model_config.expert_inter_size,
self.model_config.group_size, self.tensor_para_size)
+
+ # head_num is divisible by tp but kv_head_num is not,
+ # and tp is divisible by kv_head_num
+ assert self.model_config.head_num % self.tensor_para_size == 0
+ self.repeat_kv = 0
+ if (self.tensor_para_size > self.model_config.kv_head_num and
+ self.tensor_para_size % self.model_config.kv_head_num == 0):
+ self.repeat_kv = (self.tensor_para_size //
+ self.model_config.kv_head_num)
+ self.model_config.kv_head_num = self.tensor_para_size
+
self.model_config.verify()
assert self.model_config.kv_head_num % self.tensor_para_size == 0
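
The new `repeat_kv` rule replicates kv heads whenever the tensor-parallel size exceeds `kv_head_num` but is divisible by it, so that every rank still owns at least one kv head. A worked example of the arithmetic (values are illustrative, e.g. a GQA model with 32 query heads and 2 kv heads served with tp=8):

```python
def compute_repeat_kv(head_num: int, kv_head_num: int, tp: int):
    """Re-states the rule from target_model/base.py in isolation."""
    assert head_num % tp == 0, 'head_num must be divisible by tp'
    repeat_kv = 0
    if tp > kv_head_num and tp % kv_head_num == 0:
        repeat_kv = tp // kv_head_num   # how many times each kv head is copied
        kv_head_num = tp                # padded kv head count after replication
    return repeat_kv, kv_head_num


print(compute_repeat_kv(head_num=32, kv_head_num=2, tp=8))   # -> (4, 8)
print(compute_repeat_kv(head_num=32, kv_head_num=8, tp=4))   # -> (0, 8), no copy
```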
diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py
index 11e99edfa0..2b9c5156ed 100644
--- a/lmdeploy/turbomind/supported_models.py
+++ b/lmdeploy/turbomind/supported_models.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from lmdeploy.archs import get_model_arch
+from lmdeploy.archs import get_model_arch, search_nested_config
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
@@ -80,7 +80,12 @@ def _is_head_dim_supported(cfg):
if os.path.exists(triton_model_path):
support_by_turbomind = True
else:
+
arch, cfg = get_model_arch(model_path)
+ quant_method = search_nested_config(cfg.to_dict(), 'quant_method')
+ if quant_method and quant_method in ['smooth_quant']:
+ # turbomind doesn't support models quantized by smoothquant yet
+ return False
if arch in SUPPORTED_ARCHS.keys():
support_by_turbomind = True
diff --git a/lmdeploy/version.py b/lmdeploy/version.py
index d9f4307a78..0b4b8a5379 100644
--- a/lmdeploy/version.py
+++ b/lmdeploy/version.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
-__version__ = '0.6.3'
+__version__ = '0.6.5'
short_version = __version__
diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py
index 124fd537c6..7d490b2b77 100644
--- a/lmdeploy/vl/engine.py
+++ b/lmdeploy/vl/engine.py
@@ -1,13 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
-import inspect
-import queue
-import time
-from threading import Thread
+from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
import torch
-from PIL.Image import Image
from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig,
VisionConfig)
@@ -27,169 +23,95 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None:
raise e
-class Record:
- """Batching manager."""
-
- def __init__(self, thread_safe):
- self.thread_safe = thread_safe
- self.number = []
- self.waiting = []
- self.kwargs = []
- self.done = []
- self.res_que = []
- self.total = 0
-
- def enqueue(self, images: List[Image], kwargs: List[Dict],
- que: Union[queue.Queue, asyncio.Queue]):
- """add ith request to manager."""
- self.number.append(len(images))
- self.waiting.extend(images)
- self.kwargs.extend(kwargs)
- self.res_que.append(que)
- self.total += len(images)
- self.log('received', len(images))
-
- def dequeue(self, max_batch_size):
- """try to dequeue max batch size images."""
- inputs = self.waiting[:max_batch_size]
- kwargs = self.kwargs[:max_batch_size]
- self.waiting = self.waiting[max_batch_size:]
- self.kwargs = self.kwargs[max_batch_size:]
- self.total -= len(inputs)
- self.log('process', len(inputs))
- return inputs, kwargs
-
- def notify(self):
- """set result if request i is finished."""
- if len(self.number) == 0 or self.number[0] > len(self.done):
- return False
- num_images = self.number.pop(0)
- outputs = self.done[:num_images]
- self.done = self.done[num_images:]
- que = self.res_que.pop(0)
- self.log('done', num_images)
- if self.thread_safe:
- que._loop.call_soon_threadsafe(que.put_nowait, outputs)
- else:
- que.put_nowait(outputs)
- return True
-
- def log(self, task: str, num: int):
- logger.info(f'ImageEncoder {task} {num} images, '
- f'left {self.total} images.')
-
-
class ImageEncoder:
"""Image encoder."""
- def __init__(self,
- model_path: str,
- vision_config: VisionConfig = None,
- backend_config: Optional[Union[TurbomindEngineConfig,
- PytorchEngineConfig]] = None):
- self.model = load_vl_model(model_path, backend_config=backend_config)
+ def __init__(
+ self,
+ model_path: str,
+ backend: str,
+ vision_config: VisionConfig = None,
+ backend_config: Optional[Union[TurbomindEngineConfig,
+ PytorchEngineConfig]] = None,
+ ):
+ self.model = load_vl_model(model_path,
+ backend,
+ backend_config=backend_config)
if vision_config is None:
vision_config = VisionConfig()
self.vision_config = vision_config
self.max_batch_size = vision_config.max_batch_size
+ self.executor = ThreadPoolExecutor(max_workers=1)
torch.cuda.empty_cache()
- self._que: asyncio.Queue = None
- self._loop_task: asyncio.Task = None
- if vision_config.thread_safe:
- self._create_thread_safe_task()
-
- def _create_thread_safe_task(self):
- """thread safe loop task."""
- self._loop = asyncio.new_event_loop()
- def _work_thread():
- asyncio.set_event_loop(self._loop)
- self._que = asyncio.Queue()
- self._loop.run_until_complete(self._forward_loop())
-
- thread = Thread(target=_work_thread, daemon=True)
- thread.start()
- self._loop_thread = thread
-
- def _create_event_loop_task(self):
- """event loop task."""
- task = asyncio.get_event_loop().create_task(self._forward_loop())
- self._loop_task = task
- self._loop = task.get_loop()
-
- @property
- def req_que(self):
- if self.vision_config.thread_safe:
- return self._que
- if self._que is None:
- self._que = asyncio.Queue()
- if self._loop_task is None:
- self._create_event_loop_task()
- if asyncio.get_event_loop() != self._loop:
- raise RuntimeError('Current event loop is different from'
- ' the one bound to loop task!')
- return self._que
-
- async def _forward_loop(self):
- """working loop to process images."""
- logger.info('start ImageEncoder._forward_loop')
- record = Record(self.vision_config.thread_safe)
- while True:
- while record.total == 0 or (self._que.qsize() and
- record.total < self.max_batch_size):
- while self._que.qsize() == 0:
- await asyncio.sleep(0.01)
- item = await self._que.get()
- record.enqueue(item[0], item[1], item[2])
- inputs, kwargs = record.dequeue(self.max_batch_size)
- future = asyncio.get_event_loop().run_in_executor(
- None, self.forward, inputs, kwargs)
- future.add_done_callback(_raise_exception_on_finish)
- outputs = await future
- record.done.extend(outputs)
- while record.notify():
- pass
-
- def _init_input_params(self,
- inputs: List[Image],
- params: List[Dict] = None):
- """Check and init inputs params."""
- if params is None:
- params = [{}] * len(inputs)
- assert len(params) == len(inputs), \
- 'different length of inputs and kwargs'
- return params
-
- def forward(self, inputs: List[Image], params: List[Dict] = None):
- """Model forward."""
- params = self._init_input_params(inputs, params)
- time_start = time.perf_counter()
- func_params = inspect.signature(self.model.forward).parameters
- func_inputs = [inputs, params] if len(func_params) > 1 else [inputs]
- outputs = self.model.forward(*func_inputs)
- if isinstance(outputs[0], torch.Tensor):
- outputs = [x.cpu() for x in outputs]
- time_end = time.perf_counter()
- logger.info(f'ImageEncoder forward {len(inputs)} images, '
- f'cost {time_end - time_start:.3f}s')
+ async def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """preprocess multimodal data in the messages."""
+ future = asyncio.get_event_loop().run_in_executor(
+ self.executor, self.model.preprocess, messages)
+ future.add_done_callback(_raise_exception_on_finish)
+ outputs = await future
return outputs
- def infer(self, inputs: List[Image], params: List[Dict] = None):
- """infer."""
- params = self._init_input_params(inputs, params)
- results = self.forward(inputs, params)
- return results
+ async def async_infer(self, messages: List[Dict]) -> List[Dict]:
+ """get multimodal embedding.
+
+ Args:
+ messages (List[Dict]): a list of messages, which is the output
+ of `preprocess()`
+ """
+ future = asyncio.get_event_loop().run_in_executor(
+ self.executor, self.model.forward, messages, self.max_batch_size)
+ future.add_done_callback(_raise_exception_on_finish)
+ outputs = await future
+ return outputs
- async def async_infer(self,
- inputs: List[Image],
- params: List[Dict] = None):
- """async infer."""
- params = self._init_input_params(inputs, params)
- outputs = asyncio.Queue()
- item = (inputs, params, outputs)
- if self.vision_config.thread_safe:
- self._loop.call_soon_threadsafe(self._que.put_nowait, item)
- else:
- self.req_que.put_nowait(item)
- results = await outputs.get()
- return results
+ async def wrap_for_pytorch(self, messages: List[Dict], chat_template,
+ tokenizer, sequence_start) -> List[Dict]:
+ """
+ Args:
+ messages (List[Dict]): a list of messages, which is supposed to be
+ the output of `preprocess`
+ Returns:
+ a dict which will be passed to pytorch engine_instance's forward.
+ The dict is like the following:
+ Dict(
+ 'prompt': 'the prompt after applying chat template'
+ 'input_ids': [],
+ 'multimodal': {
+ 'pixel_values': torch.Tensor,
+ ...
+ }
+ )
+ """
+ result = self.model.to_pytorch(messages, chat_template, tokenizer,
+ sequence_start)
+ # clear data
+ for i, message in enumerate(messages):
+ if isinstance(message['content'], List):
+ messages[i]['preprocess'] = None
+ return result
+
+ async def wrap_for_turbomind(self, messages: List[Dict], chat_template,
+ tokenizer, sequence_start) -> Dict:
+ """
+ Args:
+ messages (List[Dict]): a list of messages, which is supposed to be
+ the output of `async_infer`
+ Returns:
+ a dict which will be passed to turbomind engine_instance's forward.
+ The dict is like the following:
+ Dict(
+ 'prompt': 'the prompt after applying chat template'
+ 'input_ids': [],
+ 'input_embeddings': list[torch.Tensor],
+ 'input_embedding_ranges': list[torch.Tensor],
+ ...
+ )
+ """
+ result = self.model.to_turbomind(messages, chat_template, tokenizer,
+ sequence_start)
+ # clear data
+ for i, message in enumerate(messages):
+ if isinstance(message['content'], List):
+ messages[i]['preprocess'] = None
+ messages[i]['forward'] = None
+ return result
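
The rewritten `ImageEncoder` drops the background batching loop and exposes an async `preprocess` → `forward` → `wrap_for_*` pipeline, which `VLAsyncEngine._get_prompt_input` drives per backend. A toy sketch of that control flow with stub methods (the stubs only reuse names shown in this diff and are not the real encoder API):

```python
import asyncio


class ToyEncoder:
    """Stand-in for ImageEncoder; each stage just tags the message list."""

    async def preprocess(self, messages):
        return messages + [dict(role='preprocess', content=['pixel_values'])]

    async def async_infer(self, messages):
        return messages + [dict(role='forward', content=['features'])]

    async def wrap_for_turbomind(self, messages, chat_template, tokenizer,
                                 sequence_start):
        return dict(input_ids=[], input_embeddings=['features'])

    async def wrap_for_pytorch(self, messages, chat_template, tokenizer,
                               sequence_start):
        return dict(input_ids=[], multimodal=['pixel_values'])


async def get_prompt_input(encoder, messages, backend):
    # Mirrors VLAsyncEngine._get_prompt_input: turbomind runs the vision
    # forward here, while pytorch defers the vision embedding to its engine.
    results = await encoder.preprocess(messages)
    if backend == 'turbomind':
        results = await encoder.async_infer(results)
        results = await encoder.wrap_for_turbomind(results, None, None, True)
    elif backend == 'pytorch':
        results = await encoder.wrap_for_pytorch(results, None, None, True)
    return results


print(asyncio.run(get_prompt_input(ToyEncoder(), [], 'pytorch')))
print(asyncio.run(get_prompt_input(ToyEncoder(), [], 'turbomind')))
```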
diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py
index 9c5f5f6e6a..0ee22b4688 100644
--- a/lmdeploy/vl/model/base.py
+++ b/lmdeploy/vl/model/base.py
@@ -2,8 +2,7 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
-import PIL
-import torch
+import numpy as np
from mmengine import Registry
from transformers import AutoConfig
@@ -20,35 +19,227 @@ def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
- hf_config: AutoConfig = None):
+ hf_config: AutoConfig = None,
+ backend: str = ''):
"""init."""
self.model_path = model_path
self.with_llm = with_llm
self.max_memory = max_memory
+ self.backend = backend
if hf_config is None:
_, hf_config = get_model_arch(model_path)
self.hf_config = hf_config
- self.build_model()
@abstractmethod
- def build_model():
- """build model."""
+ def build_preprocessor(self):
+ """build the preprocessor.
+
+ NOTE: When the derived class implements this method, try not to
+ introduce the upper stream model repo as a thirdparty package
+ """
raise NotImplementedError()
+ def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind.
+
+ But when `with_llm=True`, load the whole VLM model
+ """
+ if self.backend == 'turbomind' or self.with_llm:
+ raise NotImplementedError()
+
@abstractmethod
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """preprocess multimodal data in the messages. The derived class,
+ i.e., a specific vision model, takes charge of image preprocessing
+ and result management.
+ It can integrate the result into the messages list, or insert it
+ into the individual image items.
+ Args:
+ message(Dict): multimodal data in a dict, which is as follows:
+ [
+ {'role': 'user', 'content': 'user prompt'},
+ {'role': 'assistant', 'content': 'AI response'},
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'string',
+ },
+ {
+ 'type': 'image',
+ 'image': pillow.Image,
+ 'key1': value1,
+ ...
+ },
+ {
+ 'type': 'image',
+ 'image': pillow.Image,
+ 'key1': value1,
+ ...
+ },
+ ...
+ ]
+ }
+ {....}
+ ]
+ Returns:
+ the message list with preprocessing results included, which is
+ determined by the derived classes
+ """ # noqa
+ raise NotImplementedError()
+
def forward(self,
- images: List[PIL.Image.Image],
- image_kwargs: List[Dict] = None) -> List[torch.Tensor]:
- """extract image feature.
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
Args:
- images (List[PIL.Image.Image]): input images
- image_kwargs (List[Dict]): input kwargs for each images
-
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
Return:
- List[torch.Tensor]: extract image feature for each input image
+ the message list with forwarding results included, which is
+ determined by the derived classes
"""
- raise NotImplementedError()
+ if self.backend == 'turbomind':
+ raise NotImplementedError()
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ """pack the preprocessing results in a format compatible with what is
+ required by pytorch engine. ONLY implement it when the backend is
+ pytorch engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ chat_template: the chat template defined in `lmdeploy/model.py`
+ tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ if self.backend == 'pytorch':
+ raise NotImplementedError()
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ """pack the forwarding results in a format compatible with what is
+ required by turbomind engine. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ chat_template: the chat template defined in `lmdeploy/model.py`
+ tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ if self.backend == 'turbomind':
+ raise NotImplementedError()
+
+ @staticmethod
+ def collect_images(messages):
+ """gather all images along with their respective parameters from the
+ messages and compile them into a single list. Each image is converted
+ to RGB color space.
+
+ Args:
+ messages (List[Tuple[Image, Dict]]): a list of images with their
+ corresponding parameters
+ """ # noqa
+ images = []
+ for message in messages:
+ content = message['content']
+ if not isinstance(content, List):
+ continue
+ images.extend([
+ (x['image'],
+ {k: v
+ for k, v in x.items() if k not in {'type', 'image'}})
+ for x in content if x['type'] == 'image'
+ ])
+ return images
+
+ @staticmethod
+ def to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start):
+ """auxiliary function to pack the preprocessing results in a format
+ compatible with what is required by pytorch engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ prompt(str): the prompt after applying chat template
+ IMAGE_TOKEN(str): a placeholder where image tokens will be
+ inserted
+ tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ # collect all preprocessing result from messages
+ preps = [x['content'] for x in messages if x['role'] == 'preprocess']
+ assert len(preps) == 1
+ preps = preps[0]
+
+ # split prompt into segments and validate data
+ segs = prompt.split(IMAGE_TOKEN)
+ assert len(segs) == len(preps) + 1, (
+ f'the number of {IMAGE_TOKEN} is not equal '
+ f'to input images, {len(segs) - 1} vs {len(preps)}')
+
+ # calculate the image token offset for each image
+ input_ids = []
+ for i, seg in enumerate(segs):
+ if i > 0 and i <= len(preps):
+ preps[i - 1].update(offset=len(input_ids))
+ image_tokens = preps[i - 1]['image_tokens']
+ image_token_id = preps[i - 1]['image_token_id']
+ input_ids.extend([image_token_id] * image_tokens)
+ token_ids = tokenizer.encode(seg,
+ add_bos=((i == 0) and sequence_start))
+ input_ids.extend(token_ids)
+
+ return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)
+
+ @staticmethod
+ def to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start):
+ """auxiliary function to pack the forwarding results in a format
+ compatible with what is required by turbomind engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ prompt(str): the prompt after applying chat template
+ IMAGE_TOKEN(str): a placeholder where image tokens will be
+ inserted
+ tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ # collect image features from messages
+ features = [x['content'] for x in messages if x['role'] == 'forward']
+ features = features[0]
+ features = [x.cpu().numpy() for x in features]
+
+ # split prompt into segments and validate data
+ segs = prompt.split(IMAGE_TOKEN)
+ assert len(segs) == len(features) + 1, (
+ f'the number of {IMAGE_TOKEN} is not equal '
+ f'to input images, {len(segs) - 1} vs {len(features)}')
+
+ # tokenize the prompt, and get input_embeddings and input_embedding_ranges
+ input_ids = []
+ begins = []
+ ends = []
+ IMAGE_DUMMY_TOKEN_INDEX = 0
+ for i, seg in enumerate(segs):
+ if i > 0 and i <= len(features):
+ image_dim = features[i - 1].shape[0]
+ begins.append(len(input_ids))
+ ends.append(begins[-1] + image_dim)
+ input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
+ seg_ids = tokenizer.encode(seg,
+ add_bos=((i == 0) and sequence_start))
+ input_ids.extend(seg_ids)
+ ranges = np.stack([begins, ends], axis=1).tolist()
+ return dict(prompt=prompt,
+ input_ids=input_ids,
+ input_embeddings=features,
+ input_embedding_ranges=ranges)
@classmethod
def match(cls, config: AutoConfig):
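
`to_turbomind_aux` splits the templated prompt on the image placeholder, reserves a run of dummy token ids for every image feature and records `[begin, end)` ranges for embedding injection. A self-contained sketch with a toy tokenizer (the placeholder string and feature sizes are illustrative):

```python
import numpy as np

IMAGE_TOKEN = '<IMAGE_TOKEN>'                   # illustrative placeholder
prompt = f'describe {IMAGE_TOKEN} and {IMAGE_TOKEN} please'
feature_rows = [4, 4]                           # rows of each image feature
encode = lambda seg: [ord(c) for c in seg]      # toy "tokenizer"

input_ids, begins, ends = [], [], []
segs = prompt.split(IMAGE_TOKEN)
assert len(segs) == len(feature_rows) + 1
for i, seg in enumerate(segs):
    if i > 0:
        begins.append(len(input_ids))
        ends.append(begins[-1] + feature_rows[i - 1])
        input_ids.extend([0] * feature_rows[i - 1])   # IMAGE_DUMMY_TOKEN_INDEX
    input_ids.extend(encode(seg))
ranges = np.stack([begins, ends], axis=1).tolist()
print(ranges)   # [[9, 13], [18, 22]] for this toy prompt
```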
diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py
index 2401b42259..00e668c034 100644
--- a/lmdeploy/vl/model/builder.py
+++ b/lmdeploy/vl/model/builder.py
@@ -2,6 +2,8 @@
import os
from typing import Optional, Union
+import torch
+
from lmdeploy.archs import get_model_arch
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.utils import get_logger, get_model
@@ -29,6 +31,7 @@
def load_vl_model(model_path: str,
+ backend: str,
with_llm: bool = False,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None):
@@ -36,8 +39,9 @@ def load_vl_model(model_path: str,
Args:
model_path(str): the path or repo_id from model hub of the model
- with_llm(bool): whether to remove the LLM part from the model.
- When it is False, it means removing LLM part
+ backend(str): the name of inference backend
+ with_llm(bool): load LLM model or not. Set it to False for VLM
+ inference scenarios and True for VLM quantization
backend_config: the config of the inference engine
"""
if not os.path.exists(model_path):
@@ -49,7 +53,6 @@ def load_vl_model(model_path: str,
max_memory = None
if not with_llm:
- import torch
tp = getattr(backend_config, 'tp', 1)
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}
@@ -57,30 +60,21 @@ def load_vl_model(model_path: str,
kwargs = dict(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory,
- hf_config=hf_config)
+ hf_config=hf_config,
+ backend=backend)
for name, module in VISION_MODELS.module_dict.items():
try:
if module.match(hf_config):
logger.info(f'matching vision model: {name}')
- return module(**kwargs)
- except Exception:
- logger.error(f'matching vision model: {name} failed')
+ model = module(**kwargs)
+ model.build_preprocessor()
+ # build the vision part of a VLM model when backend is
+ # turbomind, or load the whole VLM model when `with_llm==True`
+ if backend == 'turbomind' or with_llm:
+ model.build_model()
+ return model
+ except Exception as e:
+ logger.error(f'build vision model {name} failed, {e}')
raise
raise ValueError(f'unsupported vl model with config {hf_config}')
-
-
-def vl_model_with_tokenizer(model_path: str, with_llm: bool = True):
- """load visual model."""
- vl_model = load_vl_model(model_path, with_llm).vl_model
- llm = vl_model
- if hasattr(vl_model, 'language_model'): # deepseek vl
- llm = vl_model.language_model
- if hasattr(vl_model, 'llm'): # MiniCPMV
- llm = vl_model.llm
- llm.config.use_cache = False
- llm.half().eval()
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained(model_path,
- trust_remote_code=True)
- return vl_model, llm, tokenizer
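
`load_vl_model` now takes the backend name and decides how much of the model to build: the preprocessor is always constructed, while the vision tower is only loaded for the turbomind backend or when `with_llm=True`. A usage sketch; the checkpoint path is a placeholder:

```python
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.vl.model.builder import load_vl_model

# For turbomind the vision weights are loaded so image features can be
# computed in `forward`; with backend='pytorch' only `build_preprocessor`
# runs and the vision embedding is left to the pytorch engine.
model = load_vl_model('/path/to/a-vlm-checkpoint',      # placeholder path
                      backend='turbomind',
                      backend_config=TurbomindEngineConfig(tp=1))
```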
diff --git a/lmdeploy/vl/model/cogvlm.py b/lmdeploy/vl/model/cogvlm.py
index ea5a06159e..07d97153f9 100644
--- a/lmdeploy/vl/model/cogvlm.py
+++ b/lmdeploy/vl/model/cogvlm.py
@@ -1,13 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import List
-
-import torch
-from PIL.Image import Image
-from transformers import AutoModelForCausalLM
+from typing import Dict, List
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
+
+logger = get_logger('lmdeploy')
@VISION_MODELS.register_module()
@@ -16,7 +13,7 @@ class CogVLMVisionModel(VisonModel):
_arch = 'CogVLMForCausalLM'
- def build_model(self):
+ def build_preprocessor(self):
from torchvision import transforms
self.image_transform = transforms.Compose([
transforms.Resize(
@@ -26,57 +23,73 @@ def build_model(self):
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
+ image_size = self.hf_config.vision_config['image_size']
+ patch_size = self.hf_config.vision_config['patch_size']
+ self.n_token_per_image = 2 + (image_size // patch_size // 2)**2
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import get_balanced_memory, infer_auto_device_map
- with init_empty_weights(), warnings.catch_warnings():
- model = AutoModelForCausalLM.from_config(self.hf_config,
- trust_remote_code=True)
- if not self.with_llm:
- del model.lm_head
- for key in ['layers', 'norm', 'embed_tokens']:
- setattr(model.model, key, None)
- else:
- self.vl_model = model
+ def build_model(self):
+ if self.with_llm:
+ from transformers import AutoModelForCausalLM
+ self.vl_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path, device_map='cpu', trust_remote_code=True)
+ else:
+ raise NotImplementedError('turbomind has not supported cogvlm yet')
- no_split_module_classes = ['TransformerLayer']
- max_memory = get_balanced_memory(
- model,
- max_memory=self.max_memory,
- dtype=torch.half,
- no_split_module_classes=no_split_module_classes)
- device_map = infer_auto_device_map(
- model,
- no_split_module_classes=no_split_module_classes,
- max_memory=max_memory,
- dtype=torch.half)
- same_device_keys = [('model.vision.linear_proj', 'model.vision.boi',
- 'model.vision.eoi')]
- for keys in same_device_keys:
- keys = [k for k in keys if k in device_map]
- if len(keys) <= 1:
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to the spec of `super().preprocess`"""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, _ in images:
+ image = image.convert('RGB')
+ pixel_values = self.image_transform(image)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
continue
- for k in keys[1:]:
- device_map[k] = device_map[keys[0]]
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+
+ prompt_messages.append(
+ dict(role='user', content=content[0], num_images=n_images))
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map=device_map if not self.with_llm else {'': 'cpu'},
- no_split_module_classes=no_split_module_classes,
- dtype=torch.half)
- self.model = model.model.vision
- self.model.eval()
+ from lmdeploy.model import Vicuna
+ llm_chat_template = Vicuna(eoa=chat_template.eoa,
+ stop_words=chat_template.stop_words)
+ prompt = ''
+ IMAGE_TOKEN = ''
+ for i, msg in enumerate(prompt_messages):
+ num_images = msg.pop('num_images', 0)
+ if num_images == 0:
+ role = msg['role']
+ msg = llm_chat_template.messages2prompt([msg], sequence_start
+ and i == 0)
+ msg = dict(role=role, content=msg)
+ prompt_i = chat_template.messages2prompt([msg], sequence_start
+ and i == 0)
+ if num_images > 0:
+ prompt_i = (IMAGE_TOKEN * num_images) + prompt_i
+ prompt += prompt_i
+ return prompt, IMAGE_TOKEN
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.image_transform(x) for x in outputs]
- outputs = torch.stack(outputs, dim=0).to(device='cuda:0',
- dtype=torch.half)
- outputs = self.model(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
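
The CogVLM preprocessor now reports a fixed token budget per image, derived from the vision config as `2 + (image_size // patch_size // 2)**2` (the extra 2 presumably covering the boi/eoi tokens referenced elsewhere in this file). A quick check of the arithmetic with illustrative config values:

```python
# Illustrative values only; the real numbers come from hf_config.vision_config.
image_size, patch_size = 224, 14
n_token_per_image = 2 + (image_size // patch_size // 2) ** 2
print(n_token_per_image)   # 224 // 14 = 16 -> 16 // 2 = 8 -> 8**2 + 2 = 66
```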
diff --git a/lmdeploy/vl/model/deepseek.py b/lmdeploy/vl/model/deepseek.py
index bfbf03f01e..9780744cf2 100644
--- a/lmdeploy/vl/model/deepseek.py
+++ b/lmdeploy/vl/model/deepseek.py
@@ -1,15 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
import warnings
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoModelForCausalLM
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
def check_deepseek_vl_install():
"""check deepseek_vl install."""
@@ -18,8 +19,8 @@ def check_deepseek_vl_install():
except ImportError:
raise ImportError(
'To use DeepSeekVLModel, please install deepseek_vl by '
- 'pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git'
- ' --no-deps')
+ '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git'
+ ' --no-deps`')
@VISION_MODELS.register_module()
@@ -28,18 +29,22 @@ class DeepSeekVisionModel(VisonModel):
_arch = 'MultiModalityCausalLM'
- def build_model(self):
+ def build_preprocessor(self):
check_deepseek_vl_install()
- # empty init
- from accelerate import init_empty_weights
from deepseek_vl.models import VLChatProcessor
+ self.image_processor = VLChatProcessor.from_pretrained(
+ self.model_path).image_processor
+
+ def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
+ from accelerate import init_empty_weights
with init_empty_weights():
warnings.simplefilter('ignore')
model = AutoModelForCausalLM.from_pretrained(self.model_path)
+ self.vl_model = model
if not self.with_llm:
del model.language_model
- else:
- self.vl_model = model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(model,
@@ -79,23 +84,111 @@ def build_model(self):
device_map=device_map if not self.with_llm else {'': 'cpu'},
dtype=torch.half)
+ self.model = model.eval()
self.vision_model = model.vision_model.eval()
self.aligner = model.aligner.eval()
- self.image_processor = VLChatProcessor.from_pretrained(
- self.model_path).image_processor
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refers to the spec of `super.preprocess()"""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, _ in images:
+ image = image.convert('RGB')
+ pixel_values = self.image_processor(
+ [image], return_tensors='pt').pixel_values
+ outputs.append(
+ dict(
+ pixel_values=pixel_values,
+ image_size=image.size,
+ # refer to https://github.com/deepseek-ai/DeepSeek-VL/blob/main/deepseek_vl/models/processing_vlm.py # noqa
+ # where it is hardcoded to 576
+ image_tokens=576,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- pixel_values = self.image_processor(outputs,
- return_tensors='pt').pixel_values
- pixel_values = pixel_values.to(device=next(
- self.vision_model.parameters()).device,
- dtype=torch.float16)
- # [b x n_images, T2, D]
- images_embeds = self.aligner(self.vision_model(pixel_values))
-
- outputs = torch.split(images_embeds, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=next(
+ self.vision_model.parameters()).device,
+ dtype=torch.float16)
+ # [b x n_images, T2, D]
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.aligner(self.vision_model(pixel_values))
+ feats = torch.split(feats, 1, dim=0)
+ outputs.extend([x.squeeze() for x in feats])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ # apply chat template to get the prompt
+ prompt_messages = []
+ IMAGE_TOKEN = ''
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ content = content[0]
+ n_image = sum(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ n_placeholder = content.count(IMAGE_TOKEN)
+ if n_placeholder == 0:
+ logger.warning(
+ f"""for deepseek-vl model, the user should insert the {IMAGE_TOKEN}
+ to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html
+ for more details.""") # noqa
+ if n_placeholder != 0 and n_placeholder != n_image:
+ logger.error(
+ f'unmatched placeholder and image: {n_placeholder} vs '
+ f'{n_image}. Ignore the placeholder')
+ content = content.replace(IMAGE_TOKEN, '')
+ n_placeholder = 0
+ if n_placeholder == 0:
+ if n_image == 1:
+ content = f'{IMAGE_TOKEN}{content}'
+ else:
+ content = ''.join([
+ f'{IMAGE_TOKEN} is Figure {str(i)}.\n'
+ for i in range(n_image)
+ ]) + content
+ prompt_messages.append(dict(role='user', content=content))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
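
`proc_messages` above enforces DeepSeek-VL's placeholder convention: no placeholders means one is prepended (or one "Figure i" line per image for several), while a count that disagrees with the number of images is discarded. A standalone rendering of that rule (the placeholder string is illustrative):

```python
IMAGE_TOKEN = '<IMAGE_TOKEN>'   # illustrative placeholder string


def insert_placeholders(content: str, n_image: int) -> str:
    n_placeholder = content.count(IMAGE_TOKEN)
    if n_placeholder not in (0, n_image):
        # mismatched placeholders are dropped, as in the diff above
        content = content.replace(IMAGE_TOKEN, '')
        n_placeholder = 0
    if n_placeholder == 0:
        if n_image == 1:
            content = f'{IMAGE_TOKEN}{content}'
        elif n_image > 1:
            content = ''.join(f'{IMAGE_TOKEN} is Figure {i}.\n'
                              for i in range(n_image)) + content
    return content


print(insert_placeholders('compare the two images', n_image=2))
# -> '<IMAGE_TOKEN> is Figure 0.\n<IMAGE_TOKEN> is Figure 1.\ncompare the two images'
```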
diff --git a/lmdeploy/vl/model/glm_4v.py b/lmdeploy/vl/model/glm_4v.py
index 34e060f4c9..813813bf09 100644
--- a/lmdeploy/vl/model/glm_4v.py
+++ b/lmdeploy/vl/model/glm_4v.py
@@ -1,77 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
-import warnings
-from typing import List
-
-import torch
-from PIL.Image import Image
from transformers import AutoConfig
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
+
+logger = get_logger('lmdeploy')
@VISION_MODELS.register_module()
class GLM4VisionModel(VisonModel):
"""glm-4v-9b vision model."""
- _arch = 'ChatGLMModel'
+ _arch = ['ChatGLMModel', 'ChatGLMForConditionalGeneration']
@classmethod
def match(cls, config: AutoConfig):
"""check whether the config match the model."""
arch = config.architectures[0]
- if arch == cls._arch and hasattr(config, 'vision_config'):
+ if arch in cls._arch and hasattr(config, 'vision_config'):
return True
return False
- def build_model(self):
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import infer_auto_device_map
+ def build_preprocessor(self):
from torchvision import transforms
-
- with init_empty_weights(), warnings.catch_warnings():
- warnings.simplefilter('ignore')
- from transformers import AutoModelForCausalLM
- model = AutoModelForCausalLM.from_config(self.hf_config,
- trust_remote_code=True)
- if not self.with_llm:
- del model.transformer.embedding
- del model.transformer.rotary_pos_emb
- del model.transformer.encoder
- del model.transformer.output_layer
- else:
- self.vl_model = model
-
- no_split_module_classes = ['TransformerLayer']
-
- device_map = infer_auto_device_map(
- model,
- no_split_module_classes=no_split_module_classes,
- max_memory=self.max_memory,
- dtype=torch.half)
-
- same_device_keys = [
- ('transformer.vision.linear_proj', 'transformer.vision.boi',
- 'transformer.vision.eoi')
- ]
- for keys in same_device_keys:
- keys = [k for k in keys if k in device_map]
- if len(keys) <= 1:
- continue
- for k in keys[1:]:
- device_map[k] = device_map[keys[0]]
-
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map=device_map if not self.with_llm else {'': 'cpu'},
- no_split_module_classes=no_split_module_classes,
- dtype=torch.half)
-
- model.eval()
- self.model = model
self.image_transform = transforms.Compose([
transforms.Resize(
(self.hf_config.vision_config['image_size'], ) * 2,
@@ -80,15 +33,65 @@ def build_model(self):
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
+ image_size = self.hf_config.vision_config['image_size']
+ patch_size = self.hf_config.vision_config['patch_size']
+ self.n_token_per_image = 2 + (image_size // patch_size // 2)**2
+
+ def build_model(self):
+ if self.with_llm:
+ from transformers import AutoModelForCausalLM
+ self.vl_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path, device_map='cpu', trust_remote_code=True)
+ else:
+            raise NotImplementedError('turbomind does not support glm4v yet')
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`."""
+ outputs = []
+ for message in messages:
+ if not isinstance(message['content'], List):
+ continue
+ images = [
+ x['image'] for x in message['content'] if x['type'] == 'image'
+ ]
+ if len(images) > 1:
+ logger.warning(
+ f'glm4v does not support the input of multiple images'
+ f' in a single chat round, but got {len(images)} images.')
+ # we still pass all the images to the model and let the
+ # model decide what to do
+ images = [x.convert('RGB') for x in images]
+ pixel_values = [self.image_transform(x) for x in images]
+ outputs.extend([
+ dict(pixel_values=_2,
+ image_size=_1.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0) for _1, _2 in zip(images, pixel_values)
+ ])
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ content = message['content']
+ if isinstance(content, str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['preprocess', 'forward']:
+ continue
+ prompt = [x['text'] for x in content if x['type'] == 'text']
+ n_images = len([1 for x in content if x['type'] == 'image'])
+ prompt = ''.join([f'{IMAGE_TOKEN}\n'] * n_images) + prompt[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.image_transform(x) for x in outputs]
- outputs = torch.stack(outputs, dim=0).to(device='cuda:0',
- dtype=torch.half)
- outputs = self.model.transformer.vision(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
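# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how GLM4VisionModel's
# build_preprocessor derives `n_token_per_image`. The config values used below
# (image_size=1120, patch_size=14) are assumptions typical of glm-4v-9b, used
# only to make the arithmetic concrete.
def glm4v_tokens_per_image(image_size: int, patch_size: int) -> int:
    # the "+ 2" presumably accounts for the boi/eoi tokens wrapping the image slot
    return 2 + (image_size // patch_size // 2) ** 2


assert glm4v_tokens_per_image(image_size=1120, patch_size=14) == 1602
# ----------------------------------------------------------------------------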
diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index fa67192f11..979b8d1a39 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -1,10 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
from typing import Dict, List
import torch
-from PIL.Image import Image
-from transformers import AutoModel, CLIPImageProcessor
+from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
@@ -80,34 +78,16 @@ class InternVLVisionModel(VisonModel):
_arch = 'InternVLChatModel'
- def build_model(self):
- """Load model."""
- from accelerate import init_empty_weights
- with init_empty_weights():
- config = self.hf_config
- # transformers below 4.37.0 may raise error about flash_attn
- config.llm_config.attn_implementation = 'eager'
- model = AutoModel.from_config(config, trust_remote_code=True)
- if not self.with_llm:
- del model.language_model
- else:
- self.vl_model = model
- model.half()
+ def __init__(self,
+ model_path: str,
+ with_llm: bool = False,
+ max_memory: Dict[int, int] = None,
+ hf_config: AutoConfig = None,
+ backend: str = ''):
+ super().__init__(model_path, with_llm, max_memory, hf_config, backend)
- from accelerate import load_checkpoint_and_dispatch
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map='auto' if not self.with_llm else {'': 'cpu'},
- max_memory=self.max_memory,
- no_split_module_classes=['InternVisionEncoderLayer'],
- dtype=torch.half)
-
- # We need eval mode to freeze the weights in model, thus,
- # avoid randomness in inference.
- self.model = model.eval()
- self.config = config
+ def build_preprocessor(self):
+ self.config = self.hf_config
dynamic_image_size = getattr(self.config, 'dynamic_image_size', False)
image_processor = None
try:
@@ -131,62 +111,180 @@ def build_model(self):
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
+ self.processor = self._preprocess_v1_5
self._forward_func = self._forward_v1_5
else:
+ self.processor = self._preprocess
self.image_processor = image_processor
self._forward_func = self._forward
- def _preprocess_v1_5(self, images: List[Image], params: List[Dict] = None):
- if params is not None:
- assert len(images) == len(
- params), 'different length of images and params'
- else:
- params = [{}] * len(images)
+ force_image_size = self.hf_config.force_image_size
+ patch_size = self.hf_config.vision_config.patch_size
+ downsample_ratio = self.hf_config.downsample_ratio
+ self.image_tokens_per_patch = int(
+ (force_image_size // patch_size)**2 * (downsample_ratio**2))
- image_res = {'low': 6, 'medium': 12, 'high': 24}
+ def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
+ from accelerate import init_empty_weights
+ with init_empty_weights():
+ # transformers below 4.37.0 may raise error about flash_attn
+ self.config.llm_config.attn_implementation = 'eager'
+ model = AutoModel.from_config(self.config, trust_remote_code=True)
+ self.vl_model = model
+ if not self.with_llm:
+ del model.language_model
- outputs = []
- for image, param in zip(images, params):
- max_num = param.get('max_dynamic_patch')
- if max_num is None or not isinstance(max_num, int):
- res_key = param.get('detail', 'default')
- max_num = image_res.get(res_key, self.config.max_dynamic_patch)
- out = dynamic_preprocess(
- image,
- min_num=self.config.min_dynamic_patch,
- max_num=max_num,
- image_size=self.config.vision_config.image_size,
- use_thumbnail=self.config.use_thumbnail)
- out = [self.transform(x) for x in out]
- out = torch.stack(out) # (patch) x c x h x w
- outputs.append(out)
- return outputs
+ model.half()
+ from accelerate import load_checkpoint_and_dispatch
+ with disable_logging():
+ load_checkpoint_and_dispatch(
+ model=model,
+ checkpoint=self.model_path,
+ device_map='auto' if not self.with_llm else {'': 'cpu'},
+ max_memory=self.max_memory,
+ no_split_module_classes=['InternVisionEncoderLayer'],
+ dtype=torch.half)
+
+ # We need eval mode to freeze the weights in model, thus,
+ # avoid randomness in inference.
+ self.model = model.eval()
+
+ def _preprocess_v1_5(self, image, params=None):
+ image_res = {'low': 6, 'medium': 12, 'high': 24}
+ max_num = params.get('max_dynamic_patch')
+ if max_num is None or not isinstance(max_num, int):
+ res_key = params.get('detail', 'default')
+ max_num = image_res.get(res_key, self.config.max_dynamic_patch)
+ out = dynamic_preprocess(
+ image,
+ min_num=self.config.min_dynamic_patch,
+ max_num=max_num,
+ image_size=self.config.vision_config.image_size,
+ use_thumbnail=self.config.use_thumbnail)
+ pixel_values = [self.transform(x) for x in out]
+ # (patch) x c x h x w
+ pixel_values = torch.stack(pixel_values)
+ return pixel_values
- def _forward_v1_5(self, images: List[Image], params: List[Dict] = None):
+ def _forward_v1_5(self, inputs, max_batch_size):
"""forward for internvl-chat-v1-5."""
- outputs = self._preprocess_v1_5(images, params)
- split = [x.shape[0] for x in outputs]
- outputs = torch.cat(outputs, dim=0)
- outputs = outputs.to(self.model.device, dtype=torch.float16)
- outputs = self.model.extract_feature(outputs)
- outputs = torch.split(outputs, split, dim=0)
- outputs = [x.reshape(-1, x.shape[-1]) for x in outputs]
+ assert all(x.get('pixel_values') is not None for x in inputs)
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ split = [x.shape[0] for x in pixel_values]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(self.model.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.model.extract_feature(pixel_values)
+ feats = torch.split(feats, split, dim=0)
+ outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
return outputs
- def _forward(self, images: List[Image], params: List[Dict] = None):
+ def _preprocess(self, image, params=None):
"""forward for internvl-chat-v1-1, internvl-chat-v1-2."""
- pixel_values = self.image_processor(images=images,
+ pixel_values = self.image_processor(images=image,
return_tensors='pt').pixel_values
- pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)
- outputs = self.model.extract_feature(pixel_values)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
+ return pixel_values
+
+ def _forward(self, inputs, max_batch_size):
+ """forward for internvl-chat-v1-1, internvl-chat-v1-2."""
+ assert all(x.get('pixel_values') is not None for x in inputs)
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(self.model.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.model.extract_feature(pixel_values)
+ feats = torch.split(feats, 1, dim=0)
+ outputs.extend([x.squeeze() for x in feats])
return outputs
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = self.processor(image, params)
+ image_tokens = (pixel_values.shape[0] *
+ self.image_tokens_per_patch)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_tokens=image_tokens,
+ image_token_id=0,
+ image_size=image.size))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
@torch.no_grad()
def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- images = [x.convert('RGB') for x in images]
- return self._forward_func(images, params)
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = self._forward_func(inputs, max_batch_size)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = content[0]
+            if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
+                prompt = prompt.replace(f'{IMAGE_TOKEN}',
+                                        f'<img>{IMAGE_TOKEN}</img>')
+                prompt = prompt.replace('</img><img>', '')
+                prompt = prompt.replace('<img><img>', '<img>')
+                prompt = prompt.replace('</img></img>', '</img>')
+            elif IMAGE_TOKEN not in prompt:
+                prompt = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + prompt
+ else:
+ pass
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
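# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the per-tile token budget that
# InternVLVisionModel.build_preprocessor computes, and how _preprocess_v1_5
# resolves the maximum number of dynamic tiles from request params. The config
# values (448 / 14 / 0.5) are assumptions matching common InternVL2 checkpoints.
def internvl_tokens_per_patch(force_image_size: int, patch_size: int,
                              downsample_ratio: float) -> int:
    return int((force_image_size // patch_size) ** 2 * downsample_ratio ** 2)


def resolve_max_dynamic_patch(params: dict, config_default: int = 12) -> int:
    # an explicit integer `max_dynamic_patch` wins; otherwise the `detail`
    # key is mapped through the same low/medium/high table as the patch uses
    image_res = {'low': 6, 'medium': 12, 'high': 24}
    max_num = params.get('max_dynamic_patch')
    if max_num is None or not isinstance(max_num, int):
        max_num = image_res.get(params.get('detail', 'default'), config_default)
    return max_num


assert internvl_tokens_per_patch(448, 14, 0.5) == 256
assert resolve_max_dynamic_patch({'detail': 'high'}) == 24
assert resolve_max_dynamic_patch({'max_dynamic_patch': 4}) == 4
assert resolve_max_dynamic_patch({}) == 12
# ----------------------------------------------------------------------------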
diff --git a/lmdeploy/vl/model/internvl_llava.py b/lmdeploy/vl/model/internvl_llava.py
index f607082b18..17a12f71ca 100644
--- a/lmdeploy/vl/model/internvl_llava.py
+++ b/lmdeploy/vl/model/internvl_llava.py
@@ -2,14 +2,13 @@
import warnings
from contextlib import contextmanager
-from typing import List, Union
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM
from lmdeploy.utils import get_logger
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel
from lmdeploy.vl.model.utils import rewrite_ctx
from .utils import disable_logging, disable_transformers_logging
@@ -18,14 +17,13 @@
def check_llava_install():
- """check llava install."""
try:
from llava.model.multimodal_encoder.clip_encoder import \
InternVisionModel # noqa: F401
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
- 'pip install "git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava" --no-deps' # noqa: E501
+ '`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`' # noqa: E501
)
@@ -65,7 +63,7 @@ def init_empty_vit():
@VISION_MODELS.register_module()
-class InternVLLlavaVisionModel(VisonModel):
+class InternVLLlavaVisionModel(LlavaVisionModel):
"""Llava visual model."""
@classmethod
@@ -78,9 +76,12 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+ return super().build_preprocessor()
+
def build_model(self):
- """build model & load weights."""
- # check llava install
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_llava_install()
# currently, only support llava llama
from llava.model.language_model.llava_llama import ( # noqa
@@ -98,13 +99,12 @@ def build_model(self):
} # disable vision part quantization
model = AutoModelForCausalLM.from_config(self.config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
- else:
- self.vl_model = model
with init_empty_vit():
vision_tower = model.get_vision_tower()
@@ -137,42 +137,43 @@ def build_model(self):
self.vision_tower = model.model.vision_tower.eval()
self.mm_projector = model.model.mm_projector.eval()
- def encode_images(self, images: torch.Tensor) -> torch.Tensor:
- """encode images."""
- image_features = self.vision_tower(images)
- image_features = self.mm_projector(image_features)
- return image_features
-
- def preprocess(
- self,
- images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
- """preprocess."""
- # TODO: gpu processor
- from llava.mm_utils import process_images
- images = [x.convert('RGB') for x in images]
- image_processor = self.vision_tower.image_processor
- outputs = process_images(images, image_processor, self.config)
- return outputs
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`."""
+ return super().preprocess(messages)
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- images = self.preprocess(images)
- if isinstance(images, list):
- images = [
- x.to(self.vision_tower.device, dtype=torch.float16)
- for x in images
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
]
- else:
- images = images.to(self.vision_tower.device, dtype=torch.float16)
-
- if type(images) is list or images.ndim == 5:
- concat_images = torch.cat([image for image in images], dim=0)
- image_features = self.encode_images(concat_images)
- split_sizes = [image.shape[0] for image in images]
- image_features = torch.split(image_features, split_sizes, dim=0)
- image_features = [x.flatten(0, 1) for x in image_features]
- else:
- image_features = self.encode_images(images)
- image_features = [x for x in image_features]
- return image_features
+ split_sizes = [x.shape[0] for x in pixel_values]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=self.vision_tower.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ if pixel_values.ndim == 5:
+ feats = self.encode_images(pixel_values)
+ feats = torch.split(feats, split_sizes, dim=0)
+ feats = [x.flatten(0, 1) for x in feats]
+ else:
+ feats = self.encode_images(pixel_values)
+ feats = [x for x in feats]
+ outputs.extend(feats)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
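# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the chunking pattern that the
# rewritten forward() methods above share — preprocess outputs are consumed in
# slices of at most `max_batch_size` images so peak vision-tower memory stays
# bounded. Tensors are replaced by plain values and `encode_batch` is a
# hypothetical stand-in for the vision tower + projector.
from typing import Callable, Dict, List


def chunked_vision_forward(inputs: List[Dict], max_batch_size: int,
                           encode_batch: Callable[[List], List]) -> List:
    outputs = []
    for idx in range(0, len(inputs), max_batch_size):
        batch = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
        outputs.extend(encode_batch(batch))
    return outputs


feats = chunked_vision_forward(
    [dict(pixel_values=i) for i in range(5)],
    max_batch_size=2,
    encode_batch=lambda batch: [v * 10 for v in batch])
assert feats == [0, 10, 20, 30, 40]
# ----------------------------------------------------------------------------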
diff --git a/lmdeploy/vl/model/llava.py b/lmdeploy/vl/model/llava.py
index 0b18f460cd..7ad919bef7 100644
--- a/lmdeploy/vl/model/llava.py
+++ b/lmdeploy/vl/model/llava.py
@@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
-# Modified from
-# https://github.com/haotian-liu/LLaVA.git
+# Modified from https://github.com/haotian-liu/LLaVA.git
+import ast
+import math
import warnings
from contextlib import contextmanager
-from typing import List, Union
+from typing import Dict, List
import torch
-from PIL.Image import Image
+from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM
from lmdeploy.utils import get_logger
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
from lmdeploy.vl.model.utils import disable_logging, rewrite_ctx
logger = get_logger('lmdeploy')
@@ -23,16 +24,14 @@ def check_llava_install():
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
- 'pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps' # noqa: E501
+ '`pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps`' # noqa: E501
)
def _clip_vision_tower_load_model(self, **kwargs):
logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}')
- from transformers import (CLIPImageProcessor, CLIPVisionConfig,
- CLIPVisionModel)
- self.image_processor = CLIPImageProcessor.from_pretrained(
- self.vision_tower_name)
+ from transformers import CLIPVisionConfig, CLIPVisionModel
+
config = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel._from_config(config=config)
self.vision_tower.requires_grad_(False)
@@ -53,8 +52,166 @@ def init_llava_vision_tower(config):
yield
+def select_best_resolution(original_size, possible_resolutions):
+ """Selects the best resolution from a list of possible resolutions based on
+ the original size.
+
+ Args:
+ original_size (tuple): The original size of the image in the format (width, height).
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (width, height).
+ """ # noqa
+ original_width, original_height = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float('inf')
+
+ for width, height in possible_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(
+ original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height,
+ original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (
+ effective_resolution == max_effective_resolution
+ and wasted_resolution < min_wasted_resolution):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (width, height)
+
+ return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+ """Resize and pad an image to a target resolution while maintaining aspect
+ ratio.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ target_resolution (tuple): The target resolution (width, height) of the image.
+
+ Returns:
+ PIL.Image.Image: The resized and padded image.
+ """ # noqa
+ original_width, original_height = image.size
+ target_width, target_height = target_resolution
+
+ scale_w = target_width / original_width
+ scale_h = target_height / original_height
+
+ if scale_w < scale_h:
+ new_width = target_width
+ new_height = min(math.ceil(original_height * scale_w), target_height)
+ else:
+ new_height = target_height
+ new_width = min(math.ceil(original_width * scale_h), target_width)
+
+ # Resize the image
+ resized_image = image.resize((new_width, new_height))
+
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+ paste_x = (target_width - new_width) // 2
+ paste_y = (target_height - new_height) // 2
+ new_image.paste(resized_image, (paste_x, paste_y))
+
+ return new_image
+
+
+def divide_to_patches(image, patch_size):
+ """Divides an image into patches of a specified size.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ patch_size (int): The size of each patch.
+
+ Returns:
+ list: A list of PIL.Image.Image objects representing the patches.
+ """
+ patches = []
+ width, height = image.size
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ box = (j, i, j + patch_size, i + patch_size)
+ patch = image.crop(box)
+ patches.append(patch)
+
+ return patches
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+ """Process an image with variable resolutions.
+
+ Args:
+ image (PIL.Image.Image): The input image to be processed.
+ processor: The image processor object.
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+ Returns:
+ torch.Tensor: A tensor containing the processed image patches.
+ """ # noqa
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
+ image_padded = resize_and_pad_image(image, best_resolution)
+
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+ image_original_resize = image.resize(
+ (processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+ image_patches = [image_original_resize] + patches
+ image_patches = [
+ processor.preprocess(image_patch,
+ return_tensors='pt')['pixel_values'][0]
+ for image_patch in image_patches
+ ]
+ return torch.stack(image_patches, dim=0)
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ elif image_aspect_ratio == 'anyres':
+ for image in images:
+ image = process_anyres_image(image, image_processor,
+ model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
@VISION_MODELS.register_module()
-class LlavaVisionModel(VisonModel):
+class LlavaVisionModel(LlavaHfVisionModel):
"""Llava visual model."""
@classmethod
@@ -73,9 +230,20 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+ from transformers import CLIPImageProcessor
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ self.hf_config.mm_vision_tower)
+ config = AutoConfig.from_pretrained(self.hf_config.mm_vision_tower)
+ image_size = config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+ self.n_token_per_image = (image_size // patch_size)**2
+ if self.hf_config.mm_vision_select_feature == 'cls_patch':
+ self.n_token_per_image += 1
+
def build_model(self):
- """build model & load weights."""
- # check llava install
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_llava_install()
self.arch = self.hf_config.architectures[0]
@@ -104,15 +272,13 @@ def build_model(self):
model = AutoModelForCausalLM.from_config(self.config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
# remove the LLM part from llava model.
- # Instead, Load the LLM part to turbomind engine
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
- else:
- self.vl_model = model
# init empty vision_tower, the embedding layer in CLIPVisionModel
# can't init right under init_empty_weights
@@ -143,101 +309,113 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
image_features = self.mm_projector(image_features)
return image_features
- def preprocess(
- self,
- images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
- """preprocess."""
- # TODO: gpu processor
- from llava.mm_utils import process_images
- images = [x.convert('RGB') for x in images]
- image_processor = self.vision_tower.image_processor
- outputs = process_images(images, image_processor, self.config)
- return outputs
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = process_images([image], self.image_processor,
+ self.config)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
from llava.model.llava_arch import (get_anyres_image_grid_shape,
unpad_image)
- image_sizes = [x.size for x in images]
- images = self.preprocess(images)
- if isinstance(images, list):
- images = [
- x.to(device=self.vision_tower.device, dtype=torch.float16)
- for x in images
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ image_sizes = [
+ x['image_size'] for x in inputs[idx:idx + max_batch_size]
]
- else:
- images = images.to(device=self.vision_tower.device,
- dtype=torch.float16)
- if type(images) is list or images.ndim == 5:
- if type(images) is list:
- images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
- concat_images = torch.cat([image for image in images], dim=0)
- image_features = self.encode_images(concat_images)
- split_sizes = [image.shape[0] for image in images]
- image_features = torch.split(image_features, split_sizes, dim=0)
- mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type',
- 'flat')
- image_aspect_ratio = getattr(self.config, 'image_aspect_ratio',
- 'square')
- if mm_patch_merge_type == 'flat':
- image_features = [x.flatten(0, 1) for x in image_features]
- elif mm_patch_merge_type.startswith('spatial'):
- new_image_features = []
- for image_idx, image_feature in enumerate(image_features):
- if image_feature.shape[0] > 1:
- base_image_feature = image_feature[0]
- image_feature = image_feature[1:]
- height = width = self.vision_tower.num_patches_per_side
- assert height * width == base_image_feature.shape[0]
- if image_aspect_ratio == 'anyres':
- num_patch_width, num_patch_height = \
- get_anyres_image_grid_shape(
- image_sizes[image_idx],
- self.config.image_grid_pinpoints,
- self.vision_tower.config.image_size)
- image_feature = image_feature.view(
- num_patch_height, num_patch_width, height,
- width, -1)
- else:
- raise NotImplementedError
- if 'unpad' in mm_patch_merge_type:
- image_feature = image_feature.permute(
- 4, 0, 2, 1, 3).contiguous()
- image_feature = image_feature.flatten(1,
- 2).flatten(
- 2, 3)
- image_feature = unpad_image(
- image_feature, image_sizes[image_idx])
- image_feature = torch.cat((
- image_feature,
- self.model.image_newline[:, None, None].expand(
- *image_feature.shape[:-1], 1).to(
- image_feature.device)),
- dim=-1)
- image_feature = image_feature.flatten(1,
- 2).transpose(
- 0, 1)
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ if pixel_values[0].ndim == 5:
+ split_sizes = [x.shape[1] for x in pixel_values]
+ pixel_values = torch.cat([x for x in pixel_values], dim=1)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ pixel_values = pixel_values.squeeze(0)
+ pixel_values = pixel_values.to(device=self.vision_tower.device,
+ dtype=torch.float16)
+ feats = self.encode_images(pixel_values)
+ feats = torch.split(feats, split_sizes, dim=0)
+ mm_patch_merge_type = getattr(self.config,
+ 'mm_patch_merge_type', 'flat')
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio',
+ 'square')
+ if mm_patch_merge_type == 'flat':
+                    outputs.extend([x.flatten(0, 1) for x in feats])
+ elif mm_patch_merge_type.startswith('spatial'):
+ for img_idx, feat in enumerate(feats):
+ if feat.shape[0] > 1:
+ base_feat = feat[0]
+ feat = feat[1:]
+ height = self.vision_tower.num_patches_per_side
+ width = self.vision_tower.num_patches_per_side
+ assert height * width == base_feat.shape[0]
+ if image_aspect_ratio == 'anyres':
+ num_patch_width, num_patch_height = \
+ get_anyres_image_grid_shape(
+ image_sizes[img_idx],
+ self.config.image_grid_pinpoints,
+ self.vision_tower.config.image_size)
+ feat = feat.view(num_patch_height,
+ num_patch_width, height,
+ width, -1)
+ else:
+ raise NotImplementedError
+ if 'unpad' in mm_patch_merge_type:
+ feat = feat.permute(4, 0, 2, 1, 3).contiguous()
+ feat = feat.flatten(1, 2).flatten(2, 3)
+ feat = unpad_image(feat, image_sizes[img_idx])
+ feat = torch.cat(
+ (feat, self.model.
+ image_newline[:, None, None].expand(
+ *feat.shape[:-1], 1).to(feat.device)),
+ dim=-1)
+ feat = feat.flatten(1, 2).transpose(0, 1)
+ else:
+ feat = feat.permute(0, 2, 1, 3, 4).contiguous()
+ feat = feat.flatten(0, 3)
+ feat = torch.cat((base_feat, feat), dim=0)
else:
- image_feature = image_feature.permute(
- 0, 2, 1, 3, 4).contiguous()
- image_feature = image_feature.flatten(0, 3)
- image_feature = torch.cat(
- (base_image_feature, image_feature), dim=0)
- else:
- image_feature = image_feature[0]
- if 'unpad' in mm_patch_merge_type:
- image_feature = torch.cat(
- (image_feature,
- self.model.image_newline[None].to(
- image_feature.device)),
- dim=0)
- new_image_features.append(image_feature)
- image_features = new_image_features
+ feat = feat[0]
+ if 'unpad' in mm_patch_merge_type:
+ feat = torch.cat(
+ (feat, self.model.image_newline[None].to(
+ feat.device)),
+ dim=0)
+ outputs.append(feat)
+ else:
+ raise ValueError('Unexpected mm_patch_merge_type: '
+ f'{self.config.mm_patch_merge_type}')
else:
- raise ValueError('Unexpected mm_patch_merge_type: '
- f'{self.config.mm_patch_merge_type}')
- else:
- image_features = self.encode_images(images)
- image_features = [x for x in image_features]
- return image_features
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=self.vision_tower.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.encode_images(pixel_values)
+ outputs.extend([x for x in feats])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
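# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): a worked example of the
# resolution selection used by process_anyres_image above. best_resolution()
# re-states the select_best_resolution algorithm from the patch in compact form
# so the example is self-contained; the grid of pinpoints and the 336x336 crop
# size are assumed LLaVA-NeXT-style values.
def best_resolution(original_size, possible_resolutions):
    original_width, original_height = original_size
    best_fit, max_eff, min_waste = None, 0, float('inf')
    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        eff = min(int(original_width * scale) * int(original_height * scale),
                  original_width * original_height)
        waste = width * height - eff
        if eff > max_eff or (eff == max_eff and waste < min_waste):
            max_eff, min_waste, best_fit = eff, waste, (width, height)
    return best_fit


grid_pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
# an 800x600 image scales by 0.84 into 672x504, the largest effective area,
# so the 672x672 grid wins and is then cut into four 336x336 tiles
assert best_resolution((800, 600), grid_pinpoints) == (672, 672)
# ----------------------------------------------------------------------------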
diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py
index 31be101ae8..c4e3c90bfb 100644
--- a/lmdeploy/vl/model/llava_hf.py
+++ b/lmdeploy/vl/model/llava_hf.py
@@ -1,15 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
import warnings
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoProcessor
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
@VISION_MODELS.register_module()
class LlavaHfVisionModel(VisonModel):
@@ -17,19 +18,31 @@ class LlavaHfVisionModel(VisonModel):
_arch = 'LlavaForConditionalGeneration'
+ def build_preprocessor(self):
+ processor = AutoProcessor.from_pretrained(self.model_path,
+ trust_remote_code=True)
+ if hasattr(processor, 'tokenizer'):
+ del processor.tokenizer
+            processor.tokenizer = None
+ self.processor = processor.image_processor
+ image_size = self.hf_config.vision_config.image_size
+ patch_size = self.hf_config.vision_config.patch_size
+ self.n_token_per_image = (image_size // patch_size)**2
+ if self.hf_config.vision_feature_select_strategy == 'full':
+ self.n_token_per_image += 1
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter('ignore')
from transformers import LlavaForConditionalGeneration
model = LlavaForConditionalGeneration._from_config(self.hf_config)
+ self.vl_model = model
if not self.with_llm:
del model.language_model
- for key in ['language_model']:
- setattr(model, key, None)
- else:
- self.vl_model = model
# fix for llava-hf/llava-interleave-qwen-7b-hf
setattr(model.config, 'tie_word_embeddings', False)
@@ -45,35 +58,97 @@ def build_model(self):
dtype=torch.half)
model.eval()
self.model = model
- # processor
- processor = AutoProcessor.from_pretrained(self.model_path,
- trust_remote_code=True)
- if hasattr(processor, 'tokenizer'):
- del processor.tokenizer
- processor.prtokenizer = None
- self.processor = processor.image_processor
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = self.processor(
+ image, return_tensors='pt',
+ input_data_format='channels_last').pixel_values
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- pixel_values = self.processor(
- images, return_tensors='pt',
- input_data_format='channels_last')['pixel_values']
- pixel_values = pixel_values.to(device=self.model.device,
- dtype=self.model.dtype)
- image_outputs = self.model.vision_tower.forward(
- pixel_values, output_hidden_states=True)
- image_features = image_outputs.hidden_states[
- self.hf_config.vision_feature_layer]
- if self.hf_config.vision_feature_select_strategy == 'default':
- image_features = image_features[:, 1:]
- elif self.hf_config.vision_feature_select_strategy == 'full':
- image_features = image_features
- else:
- raise ValueError(
- 'Unexpected select feature strategy: '
- f'{self.hf_config.vision_feature_select_strategy}')
- image_features = self.model.multi_modal_projector(image_features)
- outputs = torch.split(image_features, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=self.model.device,
+ dtype=self.model.dtype)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ image_outputs = self.model.vision_tower.forward(
+ pixel_values, output_hidden_states=True)
+ image_features = image_outputs.hidden_states[
+ self.hf_config.vision_feature_layer]
+ if self.hf_config.vision_feature_select_strategy == 'default':
+ image_features = image_features[:, 1:]
+ elif self.hf_config.vision_feature_select_strategy == 'full':
+ image_features = image_features
+ else:
+ raise ValueError(
+ 'Unexpected select feature strategy: '
+ f'{self.hf_config.vision_feature_select_strategy}')
+ image_features = self.model.multi_modal_projector(image_features)
+ image_features = torch.split(image_features, 1, dim=0)
+ outputs.extend([x.squeeze() for x in image_features])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = (IMAGE_TOKEN + '\n') * n_images + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
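# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the per-image token count that
# LlavaHfVisionModel.build_preprocessor derives from the vision config. The
# 336/14 figures are assumed CLIP ViT-L/14-336 values, used only to make the
# arithmetic concrete.
def llava_hf_tokens_per_image(image_size: int, patch_size: int,
                              vision_feature_select_strategy: str) -> int:
    n = (image_size // patch_size) ** 2
    if vision_feature_select_strategy == 'full':
        n += 1  # the 'full' strategy keeps the CLS position as well
    return n


assert llava_hf_tokens_per_image(336, 14, 'default') == 576
assert llava_hf_tokens_per_image(336, 14, 'full') == 577
# ----------------------------------------------------------------------------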
diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py
index 9223ebea4f..d355a48d60 100644
--- a/lmdeploy/vl/model/llava_next.py
+++ b/lmdeploy/vl/model/llava_next.py
@@ -1,46 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
+import itertools
import warnings
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
-from transformers import AutoProcessor
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.utils import get_logger
+from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
@VISION_MODELS.register_module()
-class LlavaNextVisionModel(VisonModel):
+class LlavaNextVisionModel(LlavaHfVisionModel):
"""Llava hf vision model."""
_arch = 'LlavaNextForConditionalGeneration'
- def build_model(self):
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
+ def build_preprocessor(self):
+ super().build_preprocessor()
+ # build the model with empty weights. The model will be used in
+ # `preprocess` to get the image token number
+ from accelerate import init_empty_weights
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter('ignore')
from transformers import LlavaNextForConditionalGeneration
- model = LlavaNextForConditionalGeneration._from_config(
+ self.model = LlavaNextForConditionalGeneration._from_config(
self.hf_config)
+ self.vl_model = self.model
if not self.with_llm:
- del model.language_model
- for key in ['language_model']:
- setattr(model, key, None)
- else:
- self.vl_model = model
+ del self.model.language_model
+
+ def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
+ from accelerate import load_checkpoint_and_dispatch
+ from accelerate.utils import get_balanced_memory, infer_auto_device_map
no_split_module_classes = ['CLIPEncoderLayer']
max_memory = get_balanced_memory(
- model,
+ self.model,
max_memory=self.max_memory,
dtype=torch.half,
no_split_module_classes=no_split_module_classes)
device_map = infer_auto_device_map(
- model,
+ self.model,
no_split_module_classes=no_split_module_classes,
max_memory=max_memory,
dtype=torch.half)
@@ -55,75 +60,128 @@ def build_model(self):
with disable_logging():
load_checkpoint_and_dispatch(
- model=model,
+ model=self.model,
checkpoint=self.model_path,
device_map=device_map if not self.with_llm else {'': 'cpu'},
no_split_module_classes=no_split_module_classes,
dtype=torch.half)
- model.eval()
- self.model = model
- # processor
- processor = AutoProcessor.from_pretrained(self.model_path,
- trust_remote_code=True)
- if hasattr(processor, 'tokenizer'):
- del processor.tokenizer
- processor.prtokenizer = None
- self.processor = processor.image_processor
+ self.model.eval()
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`."""
from transformers.models.llava_next.modeling_llava_next import \
image_size_to_num_patches
- """forward."""
- processed_inputs = self.processor(images,
- return_tensors='pt',
- input_data_format='channels_last')
- pixel_values = processed_inputs['pixel_values'].to(
- device=self.model.device, dtype=self.model.dtype)
- image_sizes = processed_inputs['image_sizes'].to(
- device=self.model.device, dtype=self.model.dtype)
- # ! infer image_num_patches from image_sizes
- image_num_patches = [
- image_size_to_num_patches(
- image_size=imsize,
- grid_pinpoints=self.hf_config.image_grid_pinpoints,
- patch_size=self.hf_config.vision_config.image_size,
- ) for imsize in image_sizes
- ]
- # figure out if pixel_values is concatenated or stacked
- if pixel_values.dim() == 5:
- # stacking when input is
- # (batch_size, num_patches, num_channels, height, width)
- _pixel_values_list = [
- pix_val[:num_patch]
- for pix_val, num_patch in zip(pixel_values, image_num_patches)
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ result = self.processor(image,
+ return_tensors='pt',
+ input_data_format='channels_last')
+ # ! infer image_num_patches from image_sizes
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.hf_config.image_grid_pinpoints,
+ patch_size=self.hf_config.vision_config.image_size,
+ ) for imsize in result['image_sizes']
]
- pixel_values = torch.cat(_pixel_values_list, dim=0)
- elif pixel_values.dim() != 4:
- # otherwise has to be stacked from list of
- # (num_patches, num_channels, height, width)
- raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
- 'expect to be of 4 or 5 dimensions')
- image_outputs = self.model.vision_tower.forward(
- pixel_values, output_hidden_states=True)
- image_features = image_outputs.hidden_states[
- self.hf_config.vision_feature_layer]
- if self.hf_config.vision_feature_select_strategy == 'default':
- image_features = image_features[:, 1:]
- elif self.hf_config.vision_feature_select_strategy == 'full':
- image_features = image_features
- else:
- raise ValueError(
- 'Unexpected select feature strategy: '
- f'{self.hf_config.vision_feature_select_strategy}')
- image_features = self.model.multi_modal_projector(image_features)
- image_features = torch.split(image_features, image_num_patches, dim=0)
- image_features, feature_lens = self.model.pack_image_features(
- image_features,
- image_sizes,
- image_newline=self.model.image_newline,
- )
- outputs = torch.split(image_features,
- feature_lens.cpu().numpy().tolist(),
- dim=0)
- return outputs
+
+ hidden_size = self.hf_config.text_config.hidden_size
+ fake_image_features = torch.zeros(
+ [image_num_patches[0], self.n_token_per_image, hidden_size])
+ image_sizes = result['image_sizes']
+ image_newline = torch.randn(self.hf_config.text_config.hidden_size)
+ strategy = self.hf_config.vision_feature_select_strategy
+ _, image_tokens = self.model.pack_image_features(
+ [fake_image_features],
+ image_sizes,
+ vision_feature_select_strategy=strategy,
+ image_newline=image_newline)
+ result.update(
+ dict(image_size=image.size,
+ image_patches=image_num_patches,
+ image_tokens=image_tokens,
+ image_token_id=0))
+ outputs.append(result)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ @torch.no_grad()
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'].to(device=self.model.device,
+ dtype=self.model.dtype)
+ for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ image_sizes = [
+ x['image_sizes'].to(device=self.model.device,
+ dtype=self.model.dtype)
+ for x in inputs[idx:idx + max_batch_size]
+ ]
+ image_sizes = torch.cat(image_sizes, dim=0)
+ image_num_patches = [
+                x['image_patches'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ image_num_patches = list(itertools.chain(*image_num_patches))
+ # figure out if pixel_values is concatenated or stacked
+ if pixel_values.dim() == 5:
+ # stacking when input is
+ # (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [
+ pix_val[:num_patch] for pix_val, num_patch in zip(
+ pixel_values, image_num_patches)
+ ]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of
+ # (num_patches, num_channels, height, width)
+ raise ValueError(
+ f'pixel_values of shape {pixel_values.shape}, '
+ 'expect to be of 4 or 5 dimensions')
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ image_outputs = self.model.vision_tower.forward(
+ pixel_values, output_hidden_states=True)
+ image_features = image_outputs.hidden_states[
+ self.hf_config.vision_feature_layer]
+ strategy = self.hf_config.vision_feature_select_strategy
+ if strategy == 'default':
+ image_features = image_features[:, 1:]
+ elif strategy == 'full':
+ image_features = image_features
+ else:
+ raise ValueError('Unexpected select feature strategy: '
+ f'{strategy}')
+ image_features = self.model.multi_modal_projector(image_features)
+ image_features = torch.split(image_features,
+ image_num_patches,
+ dim=0)
+ image_features, feature_lens = self.model.pack_image_features(
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy=strategy,
+ image_newline=self.model.image_newline,
+ )
+ image_features = torch.split(image_features,
+ feature_lens.cpu().numpy().tolist(),
+ dim=0)
+ outputs.extend(image_features)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
diff --git a/lmdeploy/vl/model/mini_gemeni.py b/lmdeploy/vl/model/mini_gemeni.py
index 0565daeba5..eca70aca51 100644
--- a/lmdeploy/vl/model/mini_gemeni.py
+++ b/lmdeploy/vl/model/mini_gemeni.py
@@ -3,16 +3,18 @@
import os.path as osp
import warnings
from contextlib import contextmanager
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import (add_device_hook, disable_logging,
disable_transformers_logging,
hack_import_with)
+logger = get_logger('lmdeploy')
+
def check_mini_gemini_install():
"""check mini gemini install."""
@@ -22,8 +24,8 @@ def check_mini_gemini_install():
except ImportError:
raise ImportError(
'To use MiniGeminiVisionModel, please install minigemini by '
- 'pip install git+https://github.com/dvlab-research/MGM.git'
- ' --no-deps')
+ '`pip install git+https://github.com/dvlab-research/MGM.git'
+ ' --no-deps`')
def _build_vision_tower(vision_tower_cfg, **kwargs):
@@ -169,7 +171,15 @@ class MiniGeminiVisionModel(VisonModel):
_arch = ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM']
+ def build_preprocessor(self):
+ # pytorch engine will not support mini-gemini. Therefore, in order to
+ # reuse the previous code as much as possible, we do not extract image
+ # preprocessor from `build_model` function.
+ pass
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_mini_gemini_install()
# empty init
from accelerate import init_empty_weights
@@ -193,13 +203,12 @@ def build_model(self):
vision_tower.load_model()
vision_tower_aux = model.get_vision_tower_aux()
vision_tower_aux.load_model()
+ self.vl_model = model
if not self.with_llm:
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
- else:
- self.vl_model = model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
@@ -246,11 +255,35 @@ def build_model(self):
self.image_processor = image_processor
self.process_images = process_images
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ return messages
+
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- image_tensor = self.process_images(outputs, self.image_processor,
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ images = []
+ for message in messages:
+ if not isinstance(message['content'], List):
+ continue
+ _ = [
+ x['image'] for x in message['content'] if x['type'] == 'image'
+ ]
+            assert len(_) == 1, f'MiniGeminiLlama accepts ONE input ' \
+                f'image, but got {len(_)} images'
+ images.extend(_)
+
+ image_tensor = self.process_images(images, self.image_processor,
self.model.config)
image_grid = getattr(self.model.config, 'image_grid', 1)
if hasattr(self.model.config, 'image_size_aux'):
@@ -301,15 +334,47 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
image.to(self.model.device, dtype=torch.float16)
for image in image_tensor_aux
]
+ logger.info(f'vision forward bs: {len(image_tensor)}')
else:
image_tensor = image_tensor.to(self.model.device,
dtype=torch.float16)
image_tensor_aux = image_tensor_aux.to(self.model.device,
dtype=torch.float16)
-
+ logger.info(f'vision forward shape: {image_tensor.shape}')
images_embeds = self.model.encode_images(image_tensor,
image_tensor_aux)
outputs = torch.split(images_embeds, 1, dim=0)
outputs = [x.squeeze() for x in outputs]
- return outputs
+        messages.append(dict(role='forward', content=outputs))
+        return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = (IMAGE_TOKEN + '\n') * n_images + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+        assert 0, 'MiniGemini is not supported by pytorch engine'
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
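Aside (not part of the patch): a minimal sketch of the GPT4V-style message list that the refactored `preprocess`/`forward`/`proc_messages` interface works on. The blank image and the prompt text below are placeholders for illustration only.

from PIL import Image

# one user turn whose content mixes an image item and a text item
messages = [
    dict(role='user',
         content=[dict(type='image', image=Image.new('RGB', (336, 336))),
                  dict(type='text', text='describe this image')])
]
# Depending on the model, `preprocess` either updates the user message or
# appends a dict(role='preprocess', content=[...]) entry, `forward` appends a
# dict(role='forward', content=[...]) entry holding the vision features, and
# `proc_messages` skips those auxiliary roles when building the chat prompt.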
diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py
index 4e30190c1d..6b0c5f1508 100644
--- a/lmdeploy/vl/model/minicpmv.py
+++ b/lmdeploy/vl/model/minicpmv.py
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
import warnings
from typing import Dict, List
import torch
from PIL.Image import Image
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM
from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
@@ -19,8 +20,33 @@ class MiniCPMVModel(VisonModel):
_arch = 'MiniCPMV'
+ def __init__(self,
+ model_path: str,
+ with_llm: bool = False,
+ max_memory: Dict[int, int] = None,
+ hf_config: AutoConfig = None,
+ backend: str = ''):
+ super().__init__(model_path, with_llm, max_memory, hf_config, backend)
+ if not hasattr(self.hf_config, 'version'):
+ raise ValueError('Can not find `version` in config.json. '
+ 'Please checkout the latest model')
+ version = str(self.hf_config.version)
+ if version not in ['2.5', '2.6']:
+ raise ValueError(
+ f'Only support v2.5 and v2.6, but got version {version}')
+ self.version = version
+
+ def build_preprocessor(self):
+ from transformers import AutoProcessor
+ self.processor = AutoProcessor.from_pretrained(self.model_path,
+ trust_remote_code=True)
+ self.image_processor = self.processor.image_processor
+ self._preprocess_func = (self._preprocess_v2_5 if self.version == '2.5'
+ else self._preprocess_v2_6)
+
def build_model(self):
- """build model & load weights."""
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter('ignore')
@@ -29,10 +55,9 @@ def build_model(self):
config.quantization_config = {} # disable vision part quantization
model = AutoModelForCausalLM.from_config(config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
del model.llm
- else:
- self.vl_model = model
from accelerate import load_checkpoint_and_dispatch
with disable_logging():
@@ -50,46 +75,11 @@ def build_model(self):
device=model.resampler.proj.device)
self.config = config
self.model = model.eval()
- self.init_forward_func()
-
- def init_forward_func(self):
- if not hasattr(self.config, 'version'):
- msg = 'LMDeploy only support `MiniCPM-V-2_6` and '\
- '`MiniCPM-Llama3-V-2_5`.\nCan not find `version` in config, ' \
- 'please consider update the huggingface model.'
- logger.warn(msg)
-
- self._forward_func = self._forward_v2_5
- if hasattr(self.config, 'version'):
- version = str(self.config.version)
- if version == '2.6':
- self._forward_func = self._forward_v2_6
-
- if self._forward_func == self._forward_v2_5:
- logger.info('using _forward_v2_5')
- if not hasattr(self.model, 'slice_image'):
- # adapt new code commit 287e3f85 (MiniCPM-Llama3-V-2_5)
- from transformers import AutoProcessor
- processor = AutoProcessor.from_pretrained(
- self.model_path, trust_remote_code=True)
- self.model.slice_image = processor.image_processor.slice_image
-
- def _reshape_by_patch(x):
- out = x.cpu().numpy()
- out = processor.image_processor.reshape_by_patch(out)
- return torch.from_numpy(out).to(device=x.device)
-
- self.model.reshape_by_patch = _reshape_by_patch
-
- if self._forward_func == self._forward_v2_6:
- logger.info('using _forward_v2_6')
- from transformers import AutoProcessor
- self.model.processor = AutoProcessor.from_pretrained(
- self.model_path, trust_remote_code=True)
def _get_slice_image(self, image: Image):
slice_images = []
- source_image, patches, best_grid = self.model.slice_image(image)
+ source_image, patches, best_grid = self.image_processor.slice_image(
+ image)
slice_images.append(source_image)
if len(patches) > 0:
for i in range(len(patches)):
@@ -103,114 +93,198 @@ def _reshape_by_patch(self, slice_images):
for slice_image in slice_images:
slice_image = self.model.transform(slice_image)
H, W = slice_image.shape[1:]
- patches.append(self.model.reshape_by_patch(slice_image))
+ slice_image = slice_image.numpy()
+ slice_image = self.image_processor.reshape_by_patch(slice_image)
+ slice_image = torch.from_numpy(slice_image)
+ patches.append(slice_image)
H //= self.config.patch_size
W //= self.config.patch_size
tgt_sizes.append(torch.Tensor([H, W]).type(torch.int32))
return patches, tgt_sizes
- def _forward_v2_5(self, images: List[Image], params: List[Dict] = None):
- """forward for MiniCPM-Llama3-V-2_5."""
- patches = []
- tgt_sizes = []
- best_grids = []
- num_patches = []
- for image in images:
- slice_images, best_grid = self._get_slice_image(image)
- _patches, _tgt_sizes = self._reshape_by_patch(slice_images)
- num_patches.append(len(_patches))
- patches.extend(_patches)
- tgt_sizes.extend(_tgt_sizes)
- best_grids.append(best_grid)
-
- patches = [
- x.to(dtype=torch.half, device=self.model.device) for x in patches
- ]
- patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches]
- tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
- max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
- all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches,
- batch_first=True,
- padding_value=0.0)
- B, L, _ = all_pixel_values.shape
- all_pixel_values = all_pixel_values.permute(0, 2,
- 1).reshape(B, 3, -1, L)
- patch_attn_mask = torch.zeros((B, 1, max_patches),
- dtype=torch.bool,
- device=self.model.device)
- for i in range(B):
- patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
- vision_embedding = self.model.vpm(
- all_pixel_values.type(torch.half),
- patch_attention_mask=patch_attn_mask).last_hidden_state
- vision_embedding = self.model.resampler(vision_embedding, tgt_sizes)
- vision_embedding = torch.split(vision_embedding, num_patches, 0)
+ def _preprocess_v2_5(self, image: Image, params: Dict = None) -> Dict:
+ """image preprocessing for MiniCPM-Llama3-V-2_5."""
+ slice_images, best_grid = self._get_slice_image(image)
+ # pixel_values, tgt_sizes are list of torch tensors
+ pixel_values, tgt_sizes = self._reshape_by_patch(slice_images)
+ num_patches = len(pixel_values)
+ return dict(
+ pixel_values=pixel_values, # a list
+ tgt_sizes=tgt_sizes, # a list
+ best_grid=best_grid,
+ num_patches=num_patches,
+ image_tokens=1,
+ image_token_id=0)
+
+ def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict:
+ """image preprocessing for MiniCPM-V-2_6."""
+ max_slice_nums = self.image_processor.max_slice_nums
+ use_image_id = self.image_processor.use_image_id
+ max_slice_nums = params.get('max_slice_nums', max_slice_nums)
+ use_image_id = params.get('use_image_id', use_image_id)
+ outputs = self.image_processor(image, max_slice_nums=max_slice_nums)
+ pixel_values = outputs['pixel_values'][0]
+ num_patches = len(pixel_values)
+ pixel_values = [torch.as_tensor(x) for x in pixel_values]
+ tgt_sizes = outputs['tgt_sizes'][0]
+ tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes]
+ grid = self.image_processor.get_sliced_grid(
+ image_size=image.size, max_slice_nums=max_slice_nums)
+ return dict(
+ pixel_values=pixel_values, # a list
+ tgt_sizes=tgt_sizes, # a list
+ best_grid=grid,
+ num_patches=num_patches,
+ image_tokens=1,
+ image_token_id=0,
+ use_image_id=use_image_id)
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to `super().preprocess() for spec."""
outputs = []
- for embeddings, grid in zip(vision_embedding, best_grids):
- embeddings = embeddings.cpu() # n x d x h
- outputs.append(dict(embeddings=embeddings, grid=grid))
+ for i, message in enumerate(messages):
+ if message['role'] != 'user' or not isinstance(
+ message['content'], List):
+ continue
+ for item in message['content']:
+ if item['type'] == 'image':
+ image = item['image'].convert('RGB')
+ params = {
+ k: v
+ for k, v in item.items() if k not in {'type', 'image'}
+ }
+ result = self._preprocess_func(image, params)
+ outputs.append(result)
+ messages[i].update(dict(preprocess=outputs))
+ return messages
- return outputs
+ @torch.no_grad()
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
- def _forward_v2_6(self, images: List[Image], params: List[Dict] = None):
- """forward for MiniCPM-V-2_6."""
- patches = []
- tgt_sizes = []
- best_grids = []
- num_patches = []
- max_slice_nums = self.model.processor.image_processor.max_slice_nums
- use_image_id = self.model.processor.image_processor.use_image_id
- for image, param in zip(images, params):
- max_slice_nums = param.get('max_slice_nums', max_slice_nums)
- use_image_id = param.get('use_image_id', use_image_id)
- outputs = self.model.processor.image_processor(
- image, max_slice_nums=max_slice_nums)
- patches.extend(outputs['pixel_values'][0])
- num_patches.append(len(outputs['pixel_values'][0]))
- tgt_sizes.extend(outputs['tgt_sizes'][0])
- grid = self.model.processor.image_processor.get_sliced_grid(
- image_size=image.size, max_slice_nums=max_slice_nums)
- best_grids.append(grid)
-
- patches = [
- torch.as_tensor(x).to(dtype=torch.half, device=self.model.device)
- for x in patches
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ # collect preprocess results into a list
+        inputs = [
+            x['preprocess'] for x in messages if 'preprocess' in x.keys()
+        ]
- patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches]
- tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes]
- tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
- max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
- all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches,
+ # flatten the list
+ inputs = list(itertools.chain(*inputs))
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ tgt_sizes = [
+ x['tgt_sizes'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ num_patches = [
+ x['num_patches'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ # flatten the list
+ tgt_sizes = list(itertools.chain(*tgt_sizes))
+ pixel_values = list(itertools.chain(*pixel_values))
+ pixel_values = [
+ x.to(dtype=torch.half, device=self.model.device)
+ for x in pixel_values
+ ]
+ pixel_values = [
+ x.flatten(end_dim=1).permute(1, 0) for x in pixel_values
+ ]
+ pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values,
batch_first=True,
padding_value=0.0)
- B, L, _ = all_pixel_values.shape
- all_pixel_values = all_pixel_values.permute(0, 2,
- 1).reshape(B, 3, -1, L)
- patch_attn_mask = torch.zeros((B, 1, max_patches),
- dtype=torch.bool,
- device=self.model.device)
- for i in range(B):
- patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
- vision_embedding = self.model.vpm(
- all_pixel_values.type(torch.half),
- patch_attention_mask=patch_attn_mask,
- tgt_sizes=tgt_sizes).last_hidden_state
- vision_embedding = self.model.resampler(vision_embedding, tgt_sizes)
- vision_embedding = torch.split(vision_embedding, num_patches, 0)
- outputs = []
- for embeddings, grid in zip(vision_embedding, best_grids):
- embeddings = embeddings.cpu() # n x d x h
- outputs.append(
- dict(embeddings=embeddings,
- grid=grid,
- use_image_id=use_image_id))
+ B, L, _ = pixel_values.shape
+ pixel_values = pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+ tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+ max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+ patch_attn_mask = torch.zeros((B, 1, max_patches),
+ dtype=torch.bool,
+ device=self.model.device)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ if self.version == '2.5':
+ for j in range(B):
+ patch_attn_mask[j, :tgt_sizes[j][0] *
+ tgt_sizes[j][1]] = True
+ embeddings = self.model.vpm(
+ pixel_values.type(torch.half),
+ patch_attention_mask=patch_attn_mask).last_hidden_state
+ else:
+ for j in range(B):
+ patch_attn_mask[j, 0, :tgt_sizes[j][0] *
+ tgt_sizes[j][1]] = True
+ embeddings = self.model.vpm(
+ pixel_values.type(torch.half),
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes).last_hidden_state
- return outputs
+ embeddings = self.model.resampler(embeddings, tgt_sizes)
+ embeddings = torch.split(embeddings, num_patches, 0)
+ for embedding in embeddings:
+ embedding = embedding.split(1, dim=0)
+ outputs.extend([x.squeeze() for x in embedding])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
- @torch.no_grad()
- def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- images = [x.convert('RGB') for x in images]
- return self._forward_func(images, params)
+ def proc_messages(self, messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ idx = 0
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ if 'preprocess' not in message.keys():
+ continue
+ prompts = []
+ for x in message['preprocess']:
+                prompt = f'<image>{IMAGE_TOKEN}</image>'
+ if x.get('use_image_id', False):
+                    prompt = f'<image_id>{idx}</image_id>' + prompt
+ idx += 1
+ grid = x['best_grid']
+ if grid is not None:
+ if self.version == '2.5':
+                        slice = '\n'.join(
+                            [f'<image>{IMAGE_TOKEN}</image>' * grid[0]] *
+                            grid[1])
+ prompt = f'{prompt}{slice}\n'
+ elif self.version == '2.6':
+                        slice = '\n'.join(
+                            [f'<slice>{IMAGE_TOKEN}</slice>' * grid[0]] *
+                            grid[1])
+ prompt = prompt + slice
+ prompt += '\n'
+ else:
+ prompt = (prompt +
+ '\n' if self.version == '2.6' else prompt)
+ prompts.append(prompt)
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = ''.join(prompts) + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
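Aside (not part of the patch): a runnable sketch of the batching pattern used in `MiniCPMVModel.forward`, where per-message preprocess results are flattened with `itertools.chain` and consumed in chunks of `max_batch_size`. The dummy dicts stand in for real preprocess outputs.

import itertools

preprocess_results = [[dict(idx=0), dict(idx=1)], [dict(idx=2)]]  # per message
inputs = list(itertools.chain(*preprocess_results))  # flatten across messages
max_batch_size = 2
for start in range(0, len(inputs), max_batch_size):
    batch = inputs[start:start + max_batch_size]
    # each chunk is padded and fed through self.model.vpm and the resampler
    print([x['idx'] for x in batch])  # -> [0, 1] then [2]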
diff --git a/lmdeploy/vl/model/mllama.py b/lmdeploy/vl/model/mllama.py
index db0a0e9cbf..0cae71cd6c 100644
--- a/lmdeploy/vl/model/mllama.py
+++ b/lmdeploy/vl/model/mllama.py
@@ -2,192 +2,10 @@
from typing import Dict, List
-import torch
-import torch.nn.functional as F
-from PIL.Image import Image
-from transformers.modeling_outputs import BaseModelOutput
-from transformers.models.mllama.modeling_mllama import MllamaPreTrainedModel
-
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
-
-
-class MllamaVisionModelPatch(MllamaPreTrainedModel):
-
- def apply_class_embedding(self,
- hidden_state: torch.Tensor) -> torch.Tensor:
- batch_size, _, hidden_size = hidden_state.shape
- class_embedding = self.class_embedding.expand(batch_size, 1,
- hidden_size)
- class_embedding = class_embedding.to(hidden_state.device)
- hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
- return hidden_state
-
- def forward(
- self,
- pixel_values: torch.Tensor,
- aspect_ratio_ids: torch.Tensor,
- aspect_ratio_mask: torch.Tensor,
- output_attentions: bool = None,
- output_hidden_states: bool = None,
- return_dict: bool = None,
- ):
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else
- self.config.output_hidden_states)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa
-
- batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape # noqa
-
- pixel_values = pixel_values.reshape(
- batch_size * num_concurrent_media * num_tiles, num_channels,
- height, width)
- aspect_ratio_ids = aspect_ratio_ids.reshape(
- batch_size * num_concurrent_media, -1)
-
- # Patch embedding
- patch_embeds = self.patch_embedding(
- pixel_values.to(self.dtype).to(self.device))
- hidden_state = patch_embeds.flatten(2).transpose(1, 2)
-
- # Tile embeddings
- _, num_patches, dim = hidden_state.shape
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles, -1, dim)
- hidden_state = self.pre_tile_positional_embedding(
- hidden_state, aspect_ratio_ids)
-
- # Add cls token
- hidden_state = hidden_state.reshape(
- batch_size * num_concurrent_media * num_tiles, num_patches, dim)
- hidden_state = self.apply_class_embedding(hidden_state)
- num_patches += 1
-
- # Position embeddings
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles, num_patches, dim)
- hidden_state = self.gated_positional_embedding(hidden_state,
- aspect_ratio_ids)
-
- hidden_state = self.layernorm_pre(hidden_state)
-
- # Compute the number of tokens to pad
- num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
- # Compute padding tuple for pad function
- padding = (
- 0, 0, 0, num_padding_patches
- ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
- # Pad the tensor
- hidden_state = F.pad(hidden_state, padding, mode='constant', value=0)
- slice_index = -num_padding_patches if num_padding_patches > 0 else None
-
- # Prepare attention mask
- attention_mask = aspect_ratio_mask.reshape(
- batch_size * num_concurrent_media, -1)
- from transformers.models.mllama.modeling_mllama import \
- _prepare_aspect_ratio_attention_mask
- attention_mask = _prepare_aspect_ratio_attention_mask(
- aspect_ratio_mask=attention_mask,
- num_patches=self.num_patches,
- target_length=hidden_state.shape[2],
- dtype=self.dtype,
- )
-
- # Apply encoder
- hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
- dim)
- output = self.transformer(
- hidden_state,
- attention_mask=attention_mask,
- output_hidden_states=True,
- output_attentions=output_attentions,
- )
- hidden_state = output[0]
-
- hidden_state = self.layernorm_post(hidden_state)
-
- # Apply global encoder
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles,
- num_patches + num_padding_patches,
- dim)
- hidden_state = self.post_tile_positional_embedding(
- hidden_state, aspect_ratio_ids)
- hidden_state = hidden_state.reshape(
- batch_size * num_concurrent_media,
- num_tiles * (num_patches + num_padding_patches), dim)
- global_output = self.global_transformer(
- hidden_state,
- attention_mask=attention_mask,
- output_hidden_states=output_hidden_states,
- output_attentions=output_attentions,
- )
- hidden_state = global_output[0]
-
- # Remove padding form hidden state
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles,
- num_patches + num_padding_patches,
- dim)
- hidden_state = hidden_state[:, :, :slice_index]
- hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
- num_tiles, num_patches, dim)
-
- # Collect intermediate layer outputs from encoder output
- all_intermediate_hidden_states = output[1]
- # rewrite to sync device during accelerate pipeline parallel
- device = hidden_state.device
- all_intermediate_hidden_states = [
- s.to(device) for s in all_intermediate_hidden_states
- ]
- intermediate_hidden_states = torch.stack(
- all_intermediate_hidden_states, dim=-1)
- intermediate_hidden_states = intermediate_hidden_states[
- ..., self.intermediate_layers_indices]
-
- # Remove padding from intermediate hidden states
- intermediate_hidden_states = intermediate_hidden_states.reshape(
- batch_size * num_concurrent_media, num_tiles,
- num_patches + num_padding_patches, -1)
- intermediate_hidden_states = intermediate_hidden_states[:, :, :
- slice_index]
- intermediate_hidden_states = intermediate_hidden_states.reshape(
- batch_size, num_concurrent_media, num_tiles, num_patches, -1)
-
- # Concatenate final hidden state and intermediate hidden states
- hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
- dim=-1)
-
- if output_hidden_states:
- hidden_states = tuple(all_intermediate_hidden_states) + tuple(
- global_output[1])
- else:
- hidden_states = None
-
- if output_attentions:
- # global transformer in contrast to `self.transformer` doesn't
- # always return hidden states so we might go index out-of-range
- global_attn = tuple(
- global_output[2]) if output_hidden_states else tuple(
- global_output[1])
- attentions = tuple(output[2]) + global_attn
- else:
- attentions = None
-
- if not return_dict:
- return tuple(v for v in [hidden_state, hidden_states, attentions]
- if v is not None)
-
- return BaseModelOutput(
- last_hidden_state=hidden_state,
- hidden_states=hidden_states,
- attentions=attentions,
- )
def check_transformers():
- """check qwen_vl_utils."""
try:
from transformers import MllamaForConditionalGeneration # noqa: F401
except ImportError:
@@ -202,85 +20,60 @@ class MllamaVLModel(VisonModel):
_arch = 'MllamaForConditionalGeneration'
- def build_model(self):
- check_transformers()
-
- from transformers.models.mllama.modeling_mllama import \
- MllamaVisionModel
- MllamaVisionModel.forward = MllamaVisionModelPatch.forward
- MllamaVisionModel.apply_class_embedding = MllamaVisionModelPatch.apply_class_embedding # noqa
- from accelerate import init_empty_weights
- with init_empty_weights():
- config = self.hf_config
- config.quantization_config = {} # disable vision part quantization
- # disable accelerate check_tied_parameters_in_config
- config.tie_word_embeddings = False
- from transformers import MllamaForConditionalGeneration
- model = MllamaForConditionalGeneration._from_config(config)
- if not self.with_llm:
- del model.language_model
- else:
- self.vl_model = model
-
- from accelerate import load_checkpoint_and_dispatch
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map='auto' if not self.with_llm else {'': 'cpu'},
- max_memory=self.max_memory,
- no_split_module_classes=[
- 'MllamaPrecomputedPositionEmbedding',
- 'MllamaPrecomputedAspectRatioEmbedding',
- 'MllamaVisionEncoderLayer'
- ],
- dtype=config.torch_dtype)
-
- self.model = model.eval()
-
- # processor
+ def build_preprocessor(self):
from transformers import AutoProcessor
self.processor = AutoProcessor.from_pretrained(self.model_path)
self.image_token_id = 128256
- @torch.no_grad()
- def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- # only support image input
- if params is not None:
- assert len(images) == len(
- params), 'different length of images and params'
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to the spec of `super().preprocess`"""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ results = self.processor.image_processor(images=image,
+ return_tensors='pt')
+ results.update(image_size=image.size,
+ image_tokens=1,
+ image_token_id=self.image_token_id)
+ outputs.append(results)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ def build_model(self):
+ check_transformers()
+ if self.with_llm:
+ from transformers import MllamaForConditionalGeneration
+ model = MllamaForConditionalGeneration.from_pretrained(
+ self.model_path, device_map='cpu')
+ self.vl_model = model
else:
- params = [{}] * len(images)
- # resize images with abnormal shape
- # TODO try catch image feature extraction in pipeline and
- # throw error back to users
- for i, image in enumerate(images):
- size = image.size
- if any([s < 3 for s in size]):
- images[i] = image.resize([s * 3 for s in size])
- image_inputs = self.processor.image_processor(images=images,
- return_tensors='pt')
- pixel_values = image_inputs['pixel_values'].to(
- self.model.vision_model.device)
- pixel_values = pixel_values.type(self.model.vision_model.dtype)
- aspect_ratio_ids = image_inputs['aspect_ratio_ids'].to(
- self.model.vision_model.device)
- aspect_ratio_mask = image_inputs['aspect_ratio_mask'].to(
- self.model.vision_model.device)
- vision_outputs = self.model.vision_model(
- pixel_values=pixel_values,
- aspect_ratio_ids=aspect_ratio_ids,
- aspect_ratio_mask=aspect_ratio_mask,
- output_hidden_states=False,
- output_attentions=False,
- return_dict=True)
- cross_attention_states = vision_outputs[0]
- cross_attention_states = self.model.multi_modal_projector(
- cross_attention_states)
- _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape)
- cross_attention_states = cross_attention_states.view(
- bsz, -1, image_token_dim).split([1] * len(images))
- return cross_attention_states
+            raise NotImplementedError(
+                'turbomind does not support mllama yet')
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+ IMAGE_TOKEN = '<|image|>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+            prompt = IMAGE_TOKEN * n_images + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
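Aside (not part of the patch): a sketch of the prompt decoration performed by `MllamaVLModel.proc_messages`, which prepends one '<|image|>' token per image item in the content list; the content below is a placeholder.

IMAGE_TOKEN = '<|image|>'
content = [dict(type='image', image=None),
           dict(type='text', text='What is in the image?')]
n_images = len([1 for x in content if x['type'] == 'image'])
text = [x['text'] for x in content if x['type'] == 'text'][0]
print(IMAGE_TOKEN * n_images + text)  # -> <|image|>What is in the image?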
diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py
index 9abae7a309..eccf62ebb6 100644
--- a/lmdeploy/vl/model/molmo.py
+++ b/lmdeploy/vl/model/molmo.py
@@ -3,11 +3,9 @@
from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from lmdeploy.utils import get_logger
-from lmdeploy.vl.constants import IMAGE_TOKEN
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
@@ -20,20 +18,26 @@ class MolmoVisionModel(VisonModel):
_arch = 'MolmoForCausalLM'
+ def build_preprocessor(self):
+ self.processor = AutoProcessor.from_pretrained(self.model_path,
+ trust_remote_code=True,
+ torch_dtype=torch.half,
+ device_map='auto')
+
def build_model(self):
- """Load model."""
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
with init_empty_weights():
- config = self.hf_config
- model = AutoModelForCausalLM.from_config(config,
+ model = AutoModelForCausalLM.from_config(self.hf_config,
trust_remote_code=True)
+
+ self.vl_model = model
if not self.with_llm:
# Remove nn modules other than embedding from the LLM model
for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']:
del model.model.transformer[key]
- self.token_embedding = model.model.transformer.wte
- else:
- self.vl_model = model
+ self.token_embedding = model.model.transformer.wte
with disable_logging():
load_checkpoint_and_dispatch(
@@ -43,118 +47,161 @@ def build_model(self):
max_memory=self.max_memory,
no_split_module_classes=[
'ResidualAttentionBlock', 'Embedding'
- ])
+ ],
+ dtype=torch.half)
# We need eval mode to freeze the weights in model, thus,
# avoid randomness in inference.
self.model = model.eval()
- self.config = config
- self.processor = AutoProcessor.from_pretrained(self.model_path,
- trust_remote_code=True,
- torch_dtype='auto',
- device_map='auto')
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to the `super.preprocess() for spec."""
+ for i, message in enumerate(messages):
+ if not isinstance(message['content'], List):
+ continue
+ images = [
+ x['image'] for x in message['content'] if x['type'] == 'image'
+ ]
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = f' User: {content[0]}'
+ tokens = self.processor.tokenizer.encode(prompt,
+ add_special_tokens=False)
+ # preprocess images. The output is a dict, which is
+ # {
+ # 'input_ids': torch.Tensor,
+            #     'images': torch.Tensor, # (num_image, num_patch, d_model)
+            #     'image_input_idx': torch.Tensor, # (num_image, num_patch)
+            #     'image_masks': torch.Tensor, # (num_image, num_patch)
+ # }
+ result = self.processor.process(images=images, tokens=tokens)
+ # remove the bos from input_ids which is prepended by molmo's
+ # processor
+ input_ids = result['input_ids'][1:]
+ result.update(input_ids=input_ids)
+ messages[i].update(preprocess=result)
+ return messages
@torch.no_grad()
def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[Dict]:
- """forward the model with given input.
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
Args:
- images (List): [None] it is not used
- params (List): the inputs after precessing GPT4V messages in
- `MolmoChatTemplateWrapper`. Its format is like the following:
- [[
- {'role': 'user', 'content': 'user prompt'},
- {'role': 'asssistant', 'content': 'assistant prompt'},
- {'role': 'user', 'content': 'user prompt', 'images': [PIL image list]},
- ...
- ]]
- """ # noqa
-
- messages = params[0]
- assert isinstance(messages, List)
- # append an assistant message to `messages`
- messages.append(dict(role='assistant', content=''))
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ for i, message in enumerate(messages):
+ if 'preprocess' not in message.keys():
+ continue
+ inputs = message['preprocess']
+ # get input_ids of embedding
+ inputs = {
+ k: v.to(self.model.device).unsqueeze(0)
+ for k, v in inputs.items()
+ }
+ input_ids = inputs['input_ids']
+ # (batch_size, num_image, num_patch, d_model)
+ images = inputs['images']
+ # (batch_size, num_image, num_patch)
+ image_input_idx = inputs['image_input_idx']
+ image_masks = inputs['image_masks']
+ batch_size, seq_len = input_ids.size()
+ assert batch_size == 1
+ input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
+ embeddings = self.model.model.transformer.wte(input_ids)
+ images = images.to(self.model.dtype)
+ image_masks = image_masks.to(self.model.dtype)
+ logger.info(f'vision forward shape: {images.shape}')
+ image_features, _ = self.model.model.vision_backbone(
+ images, image_masks)
+ num_image, num_patch = image_features.shape[1:3]
+ assert image_input_idx.shape == (batch_size, num_image, num_patch)
+
+ # insert the image feature into the embedding.
+ image_features = image_features.view(batch_size,
+ num_image * num_patch, -1)
+ image_input_idx = image_input_idx.view(batch_size,
+ num_image * num_patch)
+ valid = image_input_idx >= 0
+ batch_idx = torch.arange(batch_size, device=embeddings.device)
+ batch_idx = torch.tile(batch_idx[:, None],
+ [1, image_features.shape[1]])
+ image_features = image_features.to(embeddings.device)
+ # Since we remove bos_id from input_ids during `preprocess`,
+ # the index `image_input_idx[valid]` should be shift to left
+ # by subtracting 1
+ index = image_input_idx[valid] - 1
+ embeddings[batch_idx[valid], index] += image_features[valid]
+ assert embeddings.shape[:2] == (batch_size, seq_len)
+ messages[i].update(
+ dict(forward=dict(input_ids=input_ids.flatten(),
+ embeddings=embeddings)))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages):
+ prompt = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ role, content = message['role'], message['content']
+ if isinstance(content, List):
+ n_images = len([1 for x in content if x['type'] == 'image'])
+ content = [x['text'] for x in content if x['type'] == 'text']
+ prompt.append(' User: ' + (IMAGE_TOKEN + '\n') * n_images +
+ content[0])
+ else:
+ if role == 'user':
+ prompt.append(f' User: {content}')
+ elif role == 'assistant':
+ prompt.append(f' Assistant:{content}')
+ else:
+ assert 0, f'molmo does not support role {role}, message is {message}' # noqa
+ prompt.append(' Assistant:')
+ return ''.join(prompt)
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ assert 0, 'molmo is not supported by pytorch engine'
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
# results is a list of tuple(input_ids, embeddings)
results = []
- # the concat prompt. It is not used during inference but to adhere the
- # interface definition of `_get_prompt_input` in `class VLAsyncEngine`
- prompts = ''
# Prepend BOS
# qwen2 and olmo do not have a BOS, and instead use EOS as a generic
# separator token.
bos = (self.processor.tokenizer.bos_token_id
or self.processor.tokenizer.eos_token_id)
results.append(([bos], None))
+
for i, message in enumerate(messages):
- if 'images' in message.keys():
- prompts += ' User: ' + (IMAGE_TOKEN + '\n') * len(
- message['images']) + message['content']
- prompt = f' User: {message["content"]}'
- tokens = self.processor.tokenizer.encode(
- prompt, add_special_tokens=False)
- # preprocess images. The output is a dict
- inputs = self.processor.process(images=message['images'],
- tokens=tokens)
- inputs = {
- k: v.to(self.model.device).unsqueeze(0)
- for k, v in inputs.items()
- }
- input_ids = inputs['input_ids']
- # remove the bos from input_ids which is prepended by molmo's
- # processor
- input_ids = input_ids[:, 1:]
- images = inputs[
- 'images'] # (batch_size, num_image, num_patch, d_model)
- image_input_idx = inputs[
- 'image_input_idx'] # (batch_size, num_image, num_patch)
- image_masks = inputs['image_masks']
- batch_size, seq_len = input_ids.size()
- assert batch_size == 1
-
- # Get embeddings of input.
- if input_ids is not None:
- input_ids = input_ids * (input_ids != -1).to(
- input_ids.dtype)
- embeddings = self.model.model.transformer.wte(input_ids)
- image_features, _ = self.model.model.vision_backbone(
- images, image_masks)
- num_image, num_patch = image_features.shape[1:3]
- assert image_input_idx.shape == (batch_size, num_image,
- num_patch)
-
- # insert the image feature into the embedding.
- image_features = image_features.view(batch_size,
- num_image * num_patch, -1)
- image_input_idx = image_input_idx.view(batch_size,
- num_image * num_patch)
-
- valid = image_input_idx >= 0
- batch_idx = torch.arange(batch_size, device=embeddings.device)
- batch_idx = torch.tile(batch_idx[:, None],
- [1, image_features.shape[1]])
- image_features = image_features.to(embeddings.device)
- embeddings[batch_idx[valid],
- image_input_idx[valid]] += image_features[valid]
- assert embeddings.shape[:2] == (batch_size, seq_len)
- results.append((input_ids.flatten().tolist(), embeddings))
+ prompt = ''
+ role, content = message['role'], message['content']
+ if isinstance(content, List):
+ forward_result = message.pop('forward')
+ input_ids = forward_result['input_ids']
+ embeddings = forward_result['embeddings']
+ results.append((input_ids.tolist(), embeddings))
else:
- role = message['role']
- content = message['content']
- assert isinstance(content, str)
- prompt = ''
if role == 'user':
prompt = f' User: {content}'
elif role == 'assistant':
prompt = f' Assistant:{content}'
else:
assert 0, f'molmo does not support role {role}, message is {message}' # noqa
+ if i == len(messages) - 1:
+ # the last message
+ assert role == 'user', f'the role of last message is expected to be user, but got {role}' # noqa
+ prompt += ' Assistant:'
+ if prompt:
input_ids = self.processor.tokenizer.encode(
prompt, add_special_tokens=False)
results.append((input_ids, None))
- prompts += prompt
# concat input_ids from results, calculate the range in the input_ids
# where embeddings will be copied to
@@ -169,9 +216,9 @@ def forward(self,
input_embedding_ranges.append((start, end))
input_ids += _input_ids
start += len(_input_ids)
- return [
- dict(prompt=prompts,
- input_ids=input_ids,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges)
- ]
+
+ prompt = self.proc_messages(messages)
+ return dict(prompt=prompt,
+ input_ids=input_ids,
+ input_embeddings=input_embeddings,
+ input_embedding_ranges=input_embedding_ranges)
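Aside (not part of the patch): a toy illustration of how `MolmoVisionModel.to_turbomind` stitches the per-segment (input_ids, embeddings) pairs into a flat token list plus embedding ranges; the token ids and the embedding placeholder are made up.

results = [([1], None), ([10, 11, 12], 'img_embeds'), ([2, 3], None)]
input_ids, input_embeddings, input_embedding_ranges, start = [], [], [], 0
for ids, emb in results:
    if emb is not None:
        input_embeddings.append(emb)
        input_embedding_ranges.append((start, start + len(ids)))
    input_ids += ids
    start += len(ids)
print(input_ids, input_embedding_ranges)  # -> [1, 10, 11, 12, 2, 3] [(1, 4)]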
diff --git a/lmdeploy/vl/model/phi3_vision.py b/lmdeploy/vl/model/phi3_vision.py
index 032b8404da..ff00b5d1d9 100644
--- a/lmdeploy/vl/model/phi3_vision.py
+++ b/lmdeploy/vl/model/phi3_vision.py
@@ -1,198 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import List
+from typing import Dict, List
-import torch
-from PIL.Image import Image
from transformers import AutoProcessor
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
-
-
-# from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py # noqa E501
-def _process_image_embedding(self, pixel_values: torch.Tensor,
- image_sizes: torch.Tensor):
- """process image embedding."""
- img_embeds = pixel_values
- img_sizes = image_sizes
- target_device = pixel_values.device
- target_dtype = pixel_values.dtype
- if self.use_hd_transform and img_sizes is not None and len(img_sizes):
- assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501
- # img_embeds: (num_images, max_num_crops, 3, H, W)
- # img_sizes: (num_images, 2).view(1, -1)
-
- bs = img_embeds.shape[0]
- # Nx(HW)xC
- img_features = self.get_img_features(img_embeds.flatten(0, 1))
- base_feat_height = base_feat_width = int(img_features.shape[1]**0.5)
-
- assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501
-
- # bs x max_num_crops x (24x24) x C
- img_features = img_features.view(bs, -1,
- base_feat_height * base_feat_width,
- self.image_dim_out)
- C = self.image_dim_out
- H = base_feat_height
-
- output_imgs = []
- output_len = []
- # training is tensor, inference is list
- if isinstance(img_sizes, torch.Tensor):
- img_sizes = img_sizes.view(-1, 2)
- for _bs in range(bs):
- h, w = img_sizes[_bs]
- h = h // 336
- w = w // 336
- B_ = h * w
-
- # 1 x (24x24) x 1024
- global_img_feature = img_features[_bs, :1]
-
- # 1 x 12 x 12 x 4096
- glb_img = global_img_feature.reshape(1, H, H, C).reshape(
- 1, H // 2, 2, H // 2, 2,
- C).contiguous().permute(0, 1, 3, 2, 4,
- 5).reshape(1, H // 2, H // 2,
- 4 * C).contiguous()
- temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
-
- # 1 x 156 x 4096
- glb_img = torch.cat([glb_img, temp_glb_GN],
- dim=2).reshape(1, -1, 4 * C)
-
- # (max_num_crops-1) x (12x12) x C
- sub_img = img_features[_bs, 1:]
- # 16x574x1024
- # get rid of padding sub_img
- sub_img = sub_img[:B_]
-
- # (num_crops, 12, 2, 12, 2, 1024)->(num_crops, 12, 12, 2, 2, 1024)
- # -> (num_crops, 12*12, 4*1024)
- sub_img = sub_img.reshape(B_, H, H, C).reshape(
- B_, H // 2, 2, H // 2, 2,
- C).contiguous().permute(0, 1, 3, 2, 4,
- 5).reshape(B_, -1, 4 * C).contiguous()
- sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(
- 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C)
- temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
- sub_img = torch.cat([sub_img, temp_sub_GN],
- dim=2).reshape(1, -1, 4 * C)
- # (1, num_img_tokens, 1024*4)
-
- # glb + sub
- if self.hd_transform_order == 'glb_sub':
- output_imgs.append(
- torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
- elif self.hd_transform_order == 'sub_glb':
- output_imgs.append(
- torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
- else:
- raise NotImplementedError(
- f'hd_transform_order = {self.hd_transform_order}'
- ) # noqa E501
-
- temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
- assert temp_len == output_imgs[-1].shape[
- 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501
- output_len.append(temp_len)
-
- img_set_tensor = []
- for _output_img in output_imgs:
- img_feature_proj = self.img_projection(
- _output_img.to(target_device).to(target_dtype))
- img_set_tensor.append(img_feature_proj)
- elif img_embeds.ndim == 4:
- tt = (self.get_img_features(img_embeds).to(target_device).to(
- target_dtype).reshape(-1, self.image_dim_out))
- img_set_tensor = self.img_projection(tt) # adapted visual features.
- elif img_embeds.ndim == 3:
- tt = (img_embeds.to(target_device).to(target_dtype).view(
- -1, self.image_dim_out))
- img_set_tensor = self.img_projection(tt) # adapted visual features.
- else:
- raise NotImplementedError
- return img_set_tensor
+from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
@VISION_MODELS.register_module()
-class Phi3VisionModel(VisonModel):
- """Llava hf vision model."""
+class Phi3VisionModel(LlavaHfVisionModel):
+ """Phi3-vision model."""
_arch = 'Phi3VForCausalLM'
- def build_model(self):
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
- with init_empty_weights(), warnings.catch_warnings():
- warnings.simplefilter('ignore')
- from transformers import AutoModelForCausalLM
- model = AutoModelForCausalLM.from_config(self.hf_config,
- trust_remote_code=True)
- if not self.with_llm:
- del model.lm_head
- del model.model.layers
- del model.model.norm
- del model.model.embed_tokens
- del model.model.vision_embed_tokens.wte
- else:
- self.vl_model = model
-
- no_split_module_classes = ['CLIPEncoderLayer']
- max_memory = get_balanced_memory(
- model,
- max_memory=self.max_memory,
- dtype=torch.half,
- no_split_module_classes=no_split_module_classes)
- device_map = infer_auto_device_map(
- model,
- no_split_module_classes=no_split_module_classes,
- max_memory=max_memory,
- dtype=torch.half)
- same_device_keys = [('model.vision_embed_tokens.img_projection',
- 'model.vision_embed_tokens.sub_GN',
- 'model.vision_embed_tokens.glb_GN')]
- for keys in same_device_keys:
- keys = [k for k in keys if k in device_map]
- if len(keys) <= 1:
- continue
- for k in keys[1:]:
- device_map[k] = device_map[keys[0]]
-
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map=device_map if not self.with_llm else {'': 'cpu'},
- no_split_module_classes=no_split_module_classes,
- dtype=torch.half)
-
- model.eval()
- self.model = model
- # processor
+ def build_preprocessor(self):
processor = AutoProcessor.from_pretrained(self.model_path,
trust_remote_code=True)
if hasattr(processor, 'tokenizer'):
del processor.tokenizer
- processor.prtokenizer = None
- self.processor = processor.image_processor
+ processor.tokenizer = None
self.processor = processor
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- process_outputs = self.processor.image_processor(
- images, return_tensors='pt').to(device=self.model.device,
- dtype=self.model.dtype)
- pixel_values = process_outputs['pixel_values']
- image_sizes = process_outputs['image_sizes']
- image_features = _process_image_embedding(
- self.model.model.vision_embed_tokens,
- pixel_values=pixel_values,
- image_sizes=image_sizes)
- outputs = [x.squeeze() for x in image_features]
- return outputs
+ def build_model(self):
+ if self.with_llm:
+ from transformers import AutoModelForCausalLM
+ self.vl_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path, device_map='cpu', trust_remote_code=True)
+ else:
+            raise NotImplementedError('turbomind does not support phi3v yet')
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refers to `super.preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ result = self.processor.image_processor(image, return_tensors='pt')
+ h = result['image_sizes'][0][0].item() // 336
+ w = result['image_sizes'][0][1].item() // 336
+ image_tokens = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
+ result.update(
+ dict(image_size=image.size,
+ image_tokens=image_tokens,
+ image_token_id=0))
+ outputs.append(result)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
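Aside (not part of the patch): a worked example of the Phi3-V image-token count computed in `preprocess`. The 672x1008 processed size is an assumption chosen for illustration; the integer divisions in `preprocess` assume sizes reported in multiples of 336.

h, w = 672 // 336, 1008 // 336  # 2 tiles high, 3 tiles wide
image_tokens = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
print(image_tokens)  # -> 1045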
diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py
index 3968f27d97..49631ccf35 100644
--- a/lmdeploy/vl/model/qwen.py
+++ b/lmdeploy/vl/model/qwen.py
@@ -1,14 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoModelForCausalLM
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
@VISION_MODELS.register_module()
class QwenVisionModel(VisonModel):
@@ -16,19 +18,33 @@ class QwenVisionModel(VisonModel):
_arch = 'QWenLMHeadModel'
+ def build_preprocessor(self):
+ from torchvision import transforms
+ from torchvision.transforms import InterpolationMode
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ std = (0.26862954, 0.26130258, 0.27577711)
+ image_size = self.hf_config.visual['image_size']
+ self.image_transform = transforms.Compose([
+ transforms.Resize((image_size, image_size),
+ interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=mean, std=std),
+ ])
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights():
config = self.hf_config
config.quantization_config = {} # disable vision part quantization
model = AutoModelForCausalLM.from_config(config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
del model.lm_head
for key in ['wte', 'h', 'ln_f']:
setattr(model.transformer, key, None)
- else:
- self.vl_model = model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
@@ -60,13 +76,86 @@ def build_model(self):
self.model = model.transformer.visual.eval()
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refers to `super.preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = self.image_transform(image)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=256,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.model.image_transform(x) for x in outputs]
- outputs = torch.stack(outputs, dim=0)
- outputs = self.model(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.stack(pixel_values, dim=0)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.model(pixel_values)
+ feats = torch.split(feats, 1, dim=0)
+ outputs.extend([x.squeeze() for x in feats])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = content[0]
+            if IMAGE_TOKEN not in prompt:
+                prompt = ''.join([
+                    f'Picture {str(i)}:{IMAGE_TOKEN}\n'
+                    for i in range(n_images)
+                ]) + prompt
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
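Aside (not part of the patch): a sketch of the Qwen-VL prompt decoration applied in `QwenVisionModel.proc_messages` when the user prompt contains no image placeholder; the image count and text are placeholders.

IMAGE_TOKEN = '<IMAGE_TOKEN>'
n_images, user_text = 2, 'compare the two pictures'
prompt = ''.join(f'Picture {i}:{IMAGE_TOKEN}\n'
                 for i in range(n_images)) + user_text
print(prompt)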
diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py
index 3eb3c1541c..ed9da332e0 100644
--- a/lmdeploy/vl/model/qwen2.py
+++ b/lmdeploy/vl/model/qwen2.py
@@ -1,12 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
from typing import Dict, List
import torch
-from PIL.Image import Image
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
def check_qwen_vl_deps_install():
@@ -15,7 +12,7 @@ def check_qwen_vl_deps_install():
import qwen_vl_utils # noqa: F401
except ImportError:
raise ImportError(
- 'please install qwen_vl_utils by pip install qwen_vl_utils' # noqa: E501
+ 'please install qwen_vl_utils by `pip install qwen_vl_utils`' # noqa: E501
)
try:
from transformers import Qwen2VLForConditionalGeneration # noqa: F401
@@ -31,85 +28,105 @@ class Qwen2VLModel(VisonModel):
_arch = 'Qwen2VLForConditionalGeneration'
+ def build_preprocessor(self):
+ check_qwen_vl_deps_install()
+ from transformers import AutoProcessor
+ self.processor = AutoProcessor.from_pretrained(self.model_path)
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to `super().preprocess()` for spec."""
+ from qwen_vl_utils import process_vision_info
+
+ images = self.collect_images(messages)
+ optional_keys = {
+ 'resized_height', 'resized_width', 'min_pixels', 'max_pixels'
+ }
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+
+ item = dict(type='image', image=image)
+ item.update({
+ key: params[key]
+ for key in params.keys() if key in optional_keys
+ })
+ image_inputs, _ = process_vision_info([dict(content=[item])])
+ result = self.processor.image_processor(images=image_inputs,
+ videos=None,
+ return_tensors='pt')
+ merge_length = self.processor.image_processor.merge_size**2
+ image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
+ result.update(
+ dict(image_size=image.size,
+ image_tokens=image_tokens,
+ image_token_id=0))
+ outputs.append(result)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
def build_model(self):
check_qwen_vl_deps_install()
from transformers import Qwen2VLForConditionalGeneration
if self.with_llm:
- model = Qwen2VLForConditionalGeneration.from_pretrained(
- self.hf_config._name_or_path, trust_remote_code=True)
- model.half()
- self.vl_model = model
+ self.vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
+ self.model_path, device_map='cpu')
else:
- from accelerate import init_empty_weights
- with init_empty_weights():
- config = self.hf_config
- config.quantization_config = {
- } # disable vision part quantization
- # disable accelerate check_tied_parameters_in_config
- # for Qwen2-VL-2B-Instruct
- config.tie_word_embeddings = False
-
- model = Qwen2VLForConditionalGeneration._from_config(config)
- del model.model
- del model.lm_head
- model.half()
- from accelerate import load_checkpoint_and_dispatch
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map='auto' if not self.with_llm else {'': 'cpu'},
- max_memory=self.max_memory,
- no_split_module_classes=['Qwen2VLVisionBlock'],
- dtype=torch.half)
-
- self.model = model.eval()
-
- # processor
- from transformers import AutoProcessor
- self.processor = AutoProcessor.from_pretrained(self.model_path)
+            raise NotImplementedError(
+                'turbomind does not support qwen2-vl yet')
@torch.no_grad()
def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- # only support image input
- if params is not None:
- assert len(images) == len(
- params), 'different length of images and params'
- else:
- params = [{}] * len(images)
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
- from qwen_vl_utils import process_vision_info
- images = [x.convert('RGB') for x in images]
- content = []
- optional_keys = [
- 'resized_height', 'resized_width', 'min_pixels', 'max_pixels'
- ]
- for image, param in zip(images, params):
- item = dict(type='image', image=image)
- item.update({k: param[k] for k in optional_keys if k in param})
- content.append(item)
- messages = [dict(content=content)]
- image_inputs, _ = process_vision_info(messages)
- image_inputs = self.processor.image_processor(images=image_inputs,
- videos=None,
- return_tensors='pt')
- pixel_values = image_inputs['pixel_values'].to(
- self.model.visual.get_device())
- image_grid_thw = image_inputs['image_grid_thw'].to(
- self.model.visual.get_device())
- pixel_values = pixel_values.type(self.model.visual.get_dtype())
- image_embeds = self.model.visual(pixel_values,
- grid_thw=image_grid_thw).cpu()
- merge_length = self.processor.image_processor.merge_size**2
- split_size = image_inputs['image_grid_thw'].prod(dim=1) // merge_length
- image_embeds = image_embeds.split(split_size.tolist())
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ assert 0, 'TODO: support turbomind engine'
- outputs = []
- for i, embeddings in enumerate(image_embeds):
- outputs.append(
- dict(embeddings=embeddings,
- grid_thw=image_inputs['image_grid_thw'][i].tolist()))
- return outputs
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = content[0]
+ if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt:
+ prompt = prompt.replace(
+ IMAGE_TOKEN,
+ f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>')
+ else:
+                # Qwen2-VL-2B-Instruct concatenates the images and the user
+                # prompt according to their order in the content list. We
+                # insert the image token before the user prompt by default;
+                # users can position the image token themselves if they want
+                # the same decorated prompt as Qwen2-VL
+ prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \
+ n_images + prompt
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ """return to the information needed by pytorch engine."""
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
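# A minimal, self-contained sketch (plain Python, no lmdeploy imports) of the
# message-role protocol the refactored vision models above follow: `preprocess`
# appends a message with role 'preprocess', `forward` (turbomind only) would
# append one with role 'forward', and `proc_messages`/`to_pytorch` skip those
# engine-internal roles when building the prompt. All names and values below
# are illustrative assumptions, not the lmdeploy API.
from typing import Dict, List

INTERNAL_ROLES = {'images', 'preprocess', 'forward'}


def preprocess(messages: List[Dict]) -> List[Dict]:
    """Append per-image preprocessing results under role 'preprocess'."""
    images = [item for m in messages if isinstance(m['content'], list)
              for item in m['content'] if item['type'] == 'image']
    outputs = [dict(pixel_values=None, image_tokens=256, image_token_id=0)
               for _ in images]
    messages.append(dict(role='preprocess', content=outputs))
    return messages


def proc_messages(messages: List[Dict]) -> str:
    """Build a plain prompt, skipping the engine-internal roles."""
    parts = []
    for m in messages:
        if m['role'] in INTERNAL_ROLES:
            continue
        content = m['content']
        if isinstance(content, str):
            parts.append(content)
        else:
            parts.extend(item['text'] for item in content
                         if item['type'] == 'text')
    return '\n'.join(parts)


msgs = [dict(role='user',
             content=[dict(type='text', text='describe the image'),
                      dict(type='image', image='<a PIL image here>')])]
msgs = preprocess(msgs)
print(proc_messages(msgs))  # -> describe the image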
diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py
index 96bc900c02..3c72d0c29f 100644
--- a/lmdeploy/vl/model/xcomposer2.py
+++ b/lmdeploy/vl/model/xcomposer2.py
@@ -5,7 +5,7 @@
import sys
import warnings
from contextlib import contextmanager
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
import torch
from PIL.Image import Image
@@ -19,6 +19,17 @@
logger = get_logger('lmdeploy')
+def check_xcomposer_install():
+ try:
+ # WARNING! We have to do this, otherwise the model_type is wrong
+ # for xcomposer2d5
+ import decord # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "No module named 'decord'. Please install decord by `pip install decord`" # noqa
+ )
+
+
class ModelType(enum.Enum):
"""Request type."""
XCOMPOSER2 = enum.auto()
@@ -83,6 +94,17 @@ def init_empty_vit(model_path):
class Xcomposer2VisionModel(VisonModel):
"""InternLM-Xcomposer2 vision model."""
+ def __init__(self,
+ model_path: str,
+ with_llm: bool = False,
+ max_memory: Dict[int, int] = None,
+ hf_config: AutoConfig = None,
+ backend: str = ''):
+ super().__init__(model_path, with_llm, max_memory, hf_config, backend)
+ check_xcomposer_install()
+ self.model_type, self.module = get_xcomposer_type(self.model_path)
+ logger.info(f'matching type of {self.model_type}')
+
@classmethod
def match(cls, config: AutoConfig):
"""check whether the config match the model."""
@@ -94,7 +116,37 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+
+ import torchvision.transforms as transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+ if self.model_type in [
+ ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD
+ ]:
+ self.HD_transform = self.module
+ self.vis_processor = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+ self.preprocess_func = (self._preprocess_2d5 if self.model_type
+ == ModelType.XCOMPOSER2D5 else
+ self._preprocess_4khd_7b)
+ else:
+ self.vis_processor = transforms.Compose([
+ transforms.Resize(
+ (self.hf_config.img_size, self.hf_config.img_size),
+ interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+ self.preprocess_func = self._preprocess_7b
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights(), warnings.catch_warnings(), \
init_empty_vit(self.model_path):
@@ -106,23 +158,10 @@ def build_model(self):
model.vit.resize_pos()
model.vit.vision_tower.vision_model.post_layernorm.to_empty(
device='cpu').half()
+ self.vl_model = model
if not self.with_llm:
del model.model
del model.output
- else:
- self.vl_model = model
-
- # additional components.
- model_type, module = get_xcomposer_type(self.model_path)
- logger.info(f'matching type of {model_type}')
- if model_type == ModelType.XCOMPOSER2D5:
- self.HD_transform = module
- self._forward_func = self._forward_2d5
- elif model_type == ModelType.XCOMPOSER2_4KHD:
- self.HD_transform = module
- self._forward_func = self._forward_4khd_7b
- else:
- self._forward_func = self._forward_7b
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
@@ -156,51 +195,117 @@ def build_model(self):
self.model = model.eval()
- def _forward_2d5(self, images: List[Image]) -> List[torch.Tensor]:
- """internlm-xcomposer2d5-7b vit forward."""
- outputs = [x.convert('RGB') for x in images]
- hd_num = 6 if len(images) > 1 else 24
- outputs = [self.HD_transform(x, hd_num=hd_num) for x in outputs]
- outputs = [
- self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half)
- for x in outputs
- ]
- embeds, split = self.model.vit(outputs, self.model.plora_glb_GN,
- self.model.plora_sub_GN)
- embeds = self.model.vision_proj(embeds)
- embeds = torch.split(embeds, split, dim=1)
- embeds = [x.squeeze() for x in embeds]
- return embeds
-
- def _forward_7b(self, images: List[Image]) -> List[torch.Tensor]:
- """internlm-xcomposer2-7b vit forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [
- self.model.vis_processor(x).unsqueeze(0).half() for x in outputs
- ]
- outputs = torch.cat(outputs, dim=0)
- outputs = self.model.vit(outputs)
- outputs = self.model.vision_proj(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
-
- def _forward_4khd_7b(self, images: List[Image]) -> List[torch.Tensor]:
- """internlm-xcomposer2-4khd-7b vit forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.HD_transform(x, hd_num=25) for x in outputs]
- outputs = [
- self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half)
- for x in outputs
- ]
- embeds, split = self.model.vit(outputs, self.model.plora_glb_GN,
- self.model.plora_sub_GN)
- embeds = self.model.vision_proj(embeds)
- embeds = torch.split(embeds, split, dim=1)
- embeds = [x.squeeze() for x in embeds]
- return embeds
+ def _preprocess_2d5(self, image: Image, params: Dict) -> Dict:
+ """image preprocessing for internlm-xcomposer2d5-7b."""
+ hd_num = params.get('hd_num', 24)
+ image = self.HD_transform(image, hd_num=hd_num)
+ pixel_values = self.vis_processor(image).unsqueeze(0).half()
+ w, h = image.size
+ n_token_per_image = int((h * w + 1) * 400 + 1 + (h + 1) * 20)
+ return pixel_values, n_token_per_image
+
+ def _preprocess_7b(self, image: Image, params: Dict) -> Dict:
+ """image preprocessing for internlm-xcomposer2-7b."""
+ pixel_values = self.vis_processor(image).unsqueeze(0).half()
+ return pixel_values, 256
+
+ def _preprocess_4khd_7b(self, image: Image, params: Dict) -> Dict:
+ """image preprocessing for internlm-xcomposer2-4khd-7b."""
+ image = self.HD_transform(image, hd_num=25)
+ pixel_values = self.vis_processor(image).unsqueeze(0).half()
+ w, h = image.size
+ n_token_per_image = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
+ return pixel_values, n_token_per_image
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to `super().preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values, n_token = self.preprocess_func(image, params)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=n_token,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- return self._forward_func(images)
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ if self.model_type in [
+ ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD
+ ]:
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ embeds, split = self.model.vit(pixel_values,
+ self.model.plora_glb_GN,
+ self.model.plora_sub_GN)
+ embeds = self.model.vision_proj(embeds)
+ embeds = torch.split(embeds, split, dim=1)
+ embeds = [x.squeeze() for x in embeds]
+ else:
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ embeds = self.model.vit(pixel_values)
+ embeds = self.model.vision_proj(embeds)
+ embeds = torch.split(embeds, 1, dim=0)
+ embeds = [x.squeeze() for x in embeds]
+ outputs.extend(embeds)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+ IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = ' '.join([IMAGE_TOKEN] * n_images) + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
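# Small usage sketch of the CLIP-style preprocessing pipeline built in
# `build_preprocessor` above: ToTensor followed by Normalize with the CLIP
# mean/std that appear in the diff. Requires torchvision and Pillow; the
# 336x336 blank image is an arbitrary example, not a size read from any config.
import torchvision.transforms as transforms
from PIL import Image

vis_processor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])

image = Image.new('RGB', (336, 336))
pixel_values = vis_processor(image).unsqueeze(0).half()
print(pixel_values.shape, pixel_values.dtype)  # [1, 3, 336, 336], torch.float16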
diff --git a/lmdeploy/vl/model/yi.py b/lmdeploy/vl/model/yi.py
index 34b993322e..f8d3a907ff 100644
--- a/lmdeploy/vl/model/yi.py
+++ b/lmdeploy/vl/model/yi.py
@@ -1,12 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from contextlib import contextmanager
+from os import path as osp
+from typing import Dict, List
import torch.nn as nn
from transformers import AutoConfig
from lmdeploy.vl.model.base import VISION_MODELS
-from lmdeploy.vl.model.llava import LlavaVisionModel, check_llava_install
+from lmdeploy.vl.model.llava import (LlavaVisionModel, check_llava_install,
+ process_images)
from .utils import disable_transformers_logging, rewrite_ctx
@@ -96,8 +99,22 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+ from transformers import CLIPImageProcessor
+ vision_tower_name = osp.join(self.model_path,
+ self.hf_config.mm_vision_tower)
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ vision_tower_name)
+ config = AutoConfig.from_pretrained(vision_tower_name)
+ image_size = config.image_size
+ patch_size = config.patch_size
+ self.n_token_per_image = (image_size // patch_size)**2
+ if self.hf_config.mm_vision_select_feature == 'cls_patch':
+ self.n_token_per_image += 1
+
def build_model(self):
- """build model & load weights."""
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_llava_install()
global _model_path
@@ -105,3 +122,19 @@ def build_model(self):
with init_yi_model(), disable_transformers_logging():
super().build_model()
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to `super().preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = process_images([image], self.image_processor,
+ self.config)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
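# Worked example of the token-count arithmetic in `build_preprocessor` above:
# a CLIP-style ViT yields (image_size // patch_size) ** 2 patch tokens per
# image, plus one extra token when mm_vision_select_feature == 'cls_patch'.
# The 336/14 values are illustrative (typical of CLIP ViT-L/14-336), not read
# from any particular Yi-VL checkpoint.
image_size, patch_size = 336, 14
n_token_per_image = (image_size // patch_size) ** 2   # 24 * 24 = 576
n_token_with_cls = n_token_per_image + 1              # 577 for 'cls_patch'
print(n_token_per_image, n_token_with_cls)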
diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py
deleted file mode 100644
index cdf398868a..0000000000
--- a/lmdeploy/vl/templates.py
+++ /dev/null
@@ -1,550 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import asyncio
-from typing import Dict, List, Tuple, Union
-
-import PIL
-import PIL.Image
-
-from lmdeploy.archs import get_model_arch
-from lmdeploy.model import BaseModel
-from lmdeploy.utils import get_logger
-from lmdeploy.vl.constants import IMAGE_TOKEN
-from lmdeploy.vl.utils import load_image
-
-logger = get_logger('lmdeploy')
-
-VLPromptType = Union[str, Tuple[str, PIL.Image.Image],
- Tuple[str, List[PIL.Image.Image]]]
-
-
-class VLChatTemplateWrapper:
- """vl chat template wrapper."""
-
- def __init__(self, chat_template: BaseModel):
- self.chat_template = chat_template
-
- def prompt_to_messages(self, prompt: VLPromptType):
- """convert prompt to GTP4V format."""
- messages = {
- 'role': 'user',
- 'content': [{
- 'type': 'text',
- 'text': '',
- }]
- }
- if isinstance(prompt, str):
- messages['content'][0]['text'] = prompt
- else:
- prompt, images = prompt
- if not isinstance(images, list):
- images = [images]
- messages['content'][0]['text'] = prompt
- for image in images:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if isinstance(image, str):
- image = load_image(image)
- item = {
- 'type': 'image_data',
- 'image_data': {
- 'data': image
- }
- }
- elif isinstance(image, PIL.Image.Image):
- item = {
- 'type': 'image_data',
- 'image_data': {
- 'data': image
- }
- }
- else:
- raise ValueError(
- 'image should be a str(url/path) or PIL.Image.Image')
-
- messages['content'].append(item)
-
- return [messages]
-
- async def async_collect_pil_images(
- self, messages: Dict) -> List[Tuple[PIL.Image.Image, Dict]]:
- """collect image from messages."""
- images_with_kwargs = []
- for message in messages:
- role = message['role']
- content = message['content']
- if role != 'user' or isinstance(content, str):
- continue
- for item in content:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if item['type'] == 'image_url':
- item_copy = item['image_url'].copy()
- try:
- url = item_copy.pop('url')
- images_with_kwargs.append([url, item_copy])
- except KeyError:
- logger.error(f'invalid format {message}')
- elif item['type'] == 'image_data':
- item_copy = item['image_data'].copy()
- try:
- data = item_copy.pop('data')
- images_with_kwargs.append([data, item_copy])
- except KeyError:
- logger.error(f'invalid format {message}')
-
- def _inner_call(i, images):
- url_or_data = images[i][0]
- images[i][0] = load_image(url_or_data)
-
- await asyncio.gather(*[
- asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
- images_with_kwargs)
- for i in range(len(images_with_kwargs))
- ])
-
- return images_with_kwargs
-
- def append_image_token(self, prompt, num_images: int):
- """append image token to user prompt."""
- if IMAGE_TOKEN in prompt:
- return prompt
- return (IMAGE_TOKEN + '\n') * num_images + prompt
-
- def convert_messages(self, messages, sequence_start=True):
- """convert GPT4V message format to GPT4 text format."""
- new_messages = []
- for message in messages:
- role = message['role']
- content = message['content']
- if role != 'user' or isinstance(content, str):
- if isinstance(content, list):
- text = content[0]['text']
- message = {'role': role, 'content': text}
- new_messages.append(message)
- continue
- num_images = 0
- for item in content:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if item['type'] == 'image_url':
- num_images += 1
- elif item['type'] == 'image_data':
- num_images += 1
- elif item['type'] == 'text':
- prompt = item['text']
- if num_images > 0:
- # add IMAGE_TOKEN to user prompt
- prompt = self.append_image_token(prompt, num_images)
- new_item = {'role': 'user', 'content': prompt}
- new_messages.append(new_item)
- return new_messages
-
- def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
- """convert messages to decorated prompt."""
- if isinstance(messages, str):
- return self.chat_template.messages2prompt(messages, sequence_start)
- new_messages = self.convert_messages(messages, sequence_start)
- return self.chat_template.messages2prompt(new_messages, sequence_start)
-
-
-class LlavaVLChatTemplateWrapper(VLChatTemplateWrapper):
- """Llava vl chat template."""
- pass
-
-
-class YiVLChatTemplateWrapper(VLChatTemplateWrapper):
- """Yi vl chat template."""
- pass
-
-
-class InternVLChatTemplateWrapper(VLChatTemplateWrapper):
- """InternVL chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- # lmdeploy uses as image token
- # internvl uses special tags
- if IMAGE_TOKEN in prompt and f'{IMAGE_TOKEN}' not in prompt:
- prompt = prompt.replace(f'{IMAGE_TOKEN}',
- f'{IMAGE_TOKEN}')
- prompt = prompt.replace('', '')
- prompt = prompt.replace('', '')
- prompt = prompt.replace('', '')
- elif IMAGE_TOKEN not in prompt:
- prompt = f'{IMAGE_TOKEN * num_images}\n' + prompt
- return prompt
-
-
-class DeepSeekVLChatTemplateWrapper(VLChatTemplateWrapper):
- """DeepSeek vl chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if IMAGE_TOKEN in prompt:
- return prompt
- logger.error(
- f'for deepseek-vl model, the user should insert the {IMAGE_TOKEN} '
- 'to user prompt manually, please read https://lmdeploy.readthedocs'
- '.io/en/latest/inference/vl_pipeline.html for more details.')
- if num_images == 1:
- return f'{IMAGE_TOKEN}{prompt}'
- res = ''
- for i in range(num_images):
- res += f'{IMAGE_TOKEN} is Figure {str(i)}.\n'
- res = res + prompt
- return res
-
-
-class QwenVLChatTemplateWrapper(VLChatTemplateWrapper):
- """Qwen vl chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if IMAGE_TOKEN in prompt:
- return prompt
- res = ''
- for i in range(num_images):
- res += f'Picture {str(i)}:{IMAGE_TOKEN}\n'
- res = res + prompt
- return res
-
-
-class Qwen2VLChatTemplateWrapper(VLChatTemplateWrapper):
- """qwen2 vl."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt:
- prompt = prompt.replace(
- IMAGE_TOKEN, f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>')
- else:
- # Qwen2-VL-2B-Instruct will concat image and user prompt according
- # to their order in the content list
- # we insert image token before user prompt by default. The user can
- # use custom image token position if they want the same decorated
- # prompt as Qwen2-VL
- prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \
- num_images + prompt
- return prompt
-
- def get_mrope_info(self,
- seq_len: int,
- grid_thws: List[Tuple[int, int, int]] = None,
- embedding_ranges: List[Tuple[int, int]] = None):
- import torch
- if grid_thws is None:
- mrope_position_ids = torch.arange(seq_len).expand(3, -1)
- mrope_position_delta = torch.tensor([0], dtype=torch.long)
- else:
- mrope_position_ids = [
- torch.arange(embedding_ranges[0][0]).expand(3, -1)
- ]
- st_idx = embedding_ranges[0][0]
- for i, (grid_thw, embedding_range) in enumerate(
- zip(grid_thws, embedding_ranges)):
- llm_grid_t, llm_grid_h, llm_grid_w = grid_thw
- llm_grid_h //= 2
- llm_grid_w //= 2
- t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
- -1, llm_grid_h * llm_grid_w).flatten()
- h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
- llm_grid_t, -1, llm_grid_w).flatten()
- w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
- llm_grid_t, llm_grid_h, -1).flatten()
- mrope_position_ids.append(
- torch.stack([t_index, h_index, w_index]) + st_idx)
- st_idx += max(llm_grid_h, llm_grid_w)
- if i < len(embedding_ranges) - 1:
- text_len = embedding_ranges[i +
- 1][0] - embedding_ranges[i][1]
- else:
- text_len = seq_len - embedding_range[1]
- mrope_position_ids.append(
- torch.arange(text_len).expand(3, -1) + st_idx)
- st_idx += text_len
- mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
- mrope_position_delta = torch.tensor([st_idx - seq_len],
- dtype=torch.long)
-
- return mrope_position_ids, mrope_position_delta
-
-
-class CogVLMChatTemplateWrapper(VLChatTemplateWrapper):
- """cogvlm chat template wrapper."""
-
- def __init__(self, chat_template: BaseModel):
- from lmdeploy.model import Vicuna
- self.chat_template = chat_template
- self.llm_chat_template = Vicuna(eoa=chat_template.eoa,
- stop_words=chat_template.stop_words)
-
- def convert_messages(self, messages, sequence_start=True):
- """convert GPT4V message format to GPT4 text format."""
- new_messages = []
- for message in messages:
- role = message['role']
- content = message['content']
- if role != 'user' or isinstance(content, str):
- new_messages.append(message)
- continue
- num_images = 0
- for item in content:
- if item['type'] == 'image_url':
- num_images += 1
- elif item['type'] == 'image_data':
- num_images += 1
- elif item['type'] == 'text':
- prompt = item['text']
-
- new_item = {
- 'role': 'user',
- 'content': prompt,
- 'num_images': num_images
- }
- new_messages.append(new_item)
- return new_messages
-
- def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
- """convert messages to decorated prompt."""
- if isinstance(messages, str):
- return self.chat_template.messages2prompt(messages, sequence_start)
- new_messages = self.convert_messages(messages, sequence_start)
- prompt = ''
- for i, msg in enumerate(new_messages):
- num_images = msg.pop('num_images', 0)
- if num_images == 0:
- role = msg['role']
- msg = self.llm_chat_template.messages2prompt([msg],
- sequence_start
- and i == 0)
- msg = dict(role=role, content=msg)
- prompt_i = self.chat_template.messages2prompt([msg], sequence_start
- and i == 0)
- if num_images > 0:
- prompt_i = (IMAGE_TOKEN * num_images) + prompt_i
- prompt += prompt_i
- return prompt
-
-
-class InternLMXComposer2TemplateWrapper(VLChatTemplateWrapper):
- """InternLM-XComposer2 chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- if IMAGE_TOKEN in prompt:
- return prompt
- logger.warning(f'auto append {IMAGE_TOKEN} at the beginning, '
- 'the user can manually insert the token to prompt')
- return ' '.join([IMAGE_TOKEN] * num_images) + prompt
-
-
-class MiniGeminiLlamaTempateWrapper(VLChatTemplateWrapper):
- """Qwen vl chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if num_images == 0:
- return prompt
- if IMAGE_TOKEN in prompt:
- return prompt
- res = f'{IMAGE_TOKEN}\n'
- assert num_images <= 1, 'MiniGeminiLlama accepts 1 input image'
- res = res + prompt
- return res
-
-
-class MllamaTempateWrapper(VLChatTemplateWrapper):
- """Mllama chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- return f'{IMAGE_TOKEN * num_images}{prompt}'
-
-
-class MiniCPMVTempateWrapper(VLChatTemplateWrapper):
- """MiniCPM-Llama3-V-2_5 chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- if IMAGE_TOKEN in prompt:
- return prompt
- prompt = f'{IMAGE_TOKEN}\n' * num_images + prompt
- return prompt
-
- def update_image_token(self, prompt, features):
- _features = []
- _prompt = []
- segs = prompt.split(f'{IMAGE_TOKEN}\n')
- for i, seg in enumerate(segs):
- if i > 0 and i <= len(features):
- _feat = features[i - 1]['embeddings'].split(1)
- _feat = [x.squeeze() for x in _feat]
- _features.extend(_feat)
- _seg = f'{IMAGE_TOKEN}'
- if len(_feat) > 1:
- grid = features[i - 1]['grid']
- if grid is not None:
- _slice = '\n'.join(
- [f'{IMAGE_TOKEN}' * grid[0]] *
- grid[1])
- _seg = f'{_seg}{_slice}\n'
- _prompt.append(_seg)
- _prompt.append(seg)
- _prompt = ''.join(_prompt)
- return _prompt, _features
-
-
-class MiniCPMV26TempateWrapper(MiniCPMVTempateWrapper):
- """MiniCPM-V-2_6 chat template."""
-
- def update_image_token(self, prompt, features):
- _features = []
- _prompt = []
- segs = prompt.split(f'{IMAGE_TOKEN}\n')
- idx = 0
- for i, seg in enumerate(segs):
- if i > 0 and i <= len(features):
- _feat = features[i - 1]['embeddings'].split(1)
- _feat = [x.squeeze() for x in _feat]
- _features.extend(_feat)
- _seg = f'{IMAGE_TOKEN}'
- if features[i - 1].get('use_image_id', False):
- _seg = f'{idx}' + _seg
- idx += 1
- if len(_feat) > 1:
- grid = features[i - 1]['grid']
- if grid is not None:
- _slice = '\n'.join(
- [f'{IMAGE_TOKEN}' * grid[0]] *
- grid[1])
- _seg = _seg + _slice
- _seg += '\n'
- _prompt.append(_seg)
- _prompt.append(seg)
- _prompt = ''.join(_prompt)
- return _prompt, _features
-
-
-class GLM4VChatTemplateWrapper(VLChatTemplateWrapper):
- """glm-4v chat template."""
- pass
-
-
-class MolmoChatTemplateWrapper(VLChatTemplateWrapper):
-
- async def async_collect_pil_images(
- self, messages: List[Dict]) -> List[Tuple[PIL.Image.Image, Dict]]:
- """collect images from messages.
-
- Args:
- messages (List[Dict]): a user request of GPT4V message format
- """
- if isinstance(messages, Dict):
- messages = [messages]
- assert isinstance(messages, List)
-
- out_messages = [None] * len(messages)
-
- def _inner_call(i, in_messages, out_messages):
- role = in_messages[i]['role']
- content = in_messages[i]['content']
- if role != 'user' or isinstance(content, str):
- # means message is user's prompt input or assistant's prompt,
- # returning it directory
- out_messages[i] = in_messages[i]
- return
- # the role is a user and the content is a list
- assert isinstance(content, List)
- message = dict(role=role, content='', images=[])
- for item in content:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if item['type'] == 'image_url':
- try:
- image = load_image(item['image_url']['url'])
- message['images'].append(image)
- except KeyError:
- logger.error(f'invalid format {message}')
- elif item['type'] == 'image_data':
- try:
- image = load_image(item['image_data']['data'])
- message['images'].append(image)
- except KeyError:
- logger.error(f'invalid format {message}')
- elif item['type'] == 'text':
- message['content'] = item['text']
- else:
- logger.error(f'unexpected content type {message}')
- out_messages[i] = message
-
- await asyncio.gather(*[
- asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
- messages, out_messages)
- for i in range(len(messages))
- ])
- return [(None, out_messages)]
-
- def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
- """Return a placeholder "IMAGE_TOKEN" so that
- `vl_asyn_engine._get_prompt_input` can know that it."""
- if isinstance(messages, str):
- return self.chat_template.messages2prompt(messages, sequence_start)
- else:
- _messages = []
- for message in messages:
- role, content = message['role'], message['content']
- if role != 'user' or isinstance(content, str):
- _messages.append(message)
- continue
- for item in content:
- item_type = item['type']
- if item_type in ['image_url', 'image_data']:
- # Return the image placeholder so that
- # `vl_asyn_engine._get_prompt_input` can know that the
- # request contains images
- return IMAGE_TOKEN
- _messages.append(dict(role=role, content=item[item_type]))
- return self.chat_template.messages2prompt(_messages,
- sequence_start)
-
-
-def get_vl_prompt_template(model_path: str, chat_template: BaseModel,
- model_name: str) -> VLChatTemplateWrapper:
- """get vision language prompt template."""
- assert type(chat_template) != type(BaseModel()), 'failed to match ' \
- 'chat template, please explicit set chat_template_config' # noqa E721
- if model_name == 'yi-vl':
- return YiVLChatTemplateWrapper(chat_template)
- arch, cfg = get_model_arch(model_path)
- if arch == 'QWenLMHeadModel':
- return QwenVLChatTemplateWrapper(chat_template)
- elif arch in [
- 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM',
- 'LlavaForConditionalGeneration',
- 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM'
- ]:
- return LlavaVLChatTemplateWrapper(chat_template)
- elif arch == 'MultiModalityCausalLM': # deepseek-vl
- return DeepSeekVLChatTemplateWrapper(chat_template)
- elif arch == 'MllamaForConditionalGeneration': # llama 3.2
- return MllamaTempateWrapper(chat_template)
- elif arch == 'CogVLMForCausalLM':
- return CogVLMChatTemplateWrapper(chat_template)
- elif arch in ['InternLMXComposer2ForCausalLM', 'InternLM2ForCausalLM']:
- return InternLMXComposer2TemplateWrapper(chat_template)
- elif arch == 'InternVLChatModel':
- return InternVLChatTemplateWrapper(chat_template)
- elif arch in ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM']:
- return MiniGeminiLlamaTempateWrapper(chat_template)
- elif arch == 'MiniCPMV':
- version_map = {
- '2.5': MiniCPMVTempateWrapper,
- '2.6': MiniCPMV26TempateWrapper
- }
- version = str(getattr(cfg, 'version', '2.5'))
- return version_map[version](chat_template)
- elif arch == 'ChatGLMModel':
- return GLM4VChatTemplateWrapper(chat_template)
- elif arch == 'Qwen2VLForConditionalGeneration':
- return Qwen2VLChatTemplateWrapper(chat_template)
- elif arch == 'MolmoForCausalLM':
- return MolmoChatTemplateWrapper(chat_template)
- raise ValueError(f'unsupported vl_prompt_template with arch {arch}')
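# For reference while reviewing the deletion above: a standalone re-statement of
# what the removed VLChatTemplateWrapper.prompt_to_messages did, i.e. turn a
# plain prompt or a (prompt, images) tuple into a GPT-4V style message list.
# Simplified from the deleted code (string URLs/paths handled by load_image are
# omitted); it is not an lmdeploy import.
from typing import List, Tuple, Union

from PIL import Image

VLPrompt = Union[str, Tuple[str, List[Image.Image]]]


def prompt_to_messages(prompt: VLPrompt) -> List[dict]:
    message = {'role': 'user', 'content': [{'type': 'text', 'text': ''}]}
    if isinstance(prompt, str):
        message['content'][0]['text'] = prompt
    else:
        text, images = prompt
        message['content'][0]['text'] = text
        for image in images:
            message['content'].append(
                {'type': 'image_data', 'image_data': {'data': image}})
    return [message]


print(prompt_to_messages(('hi', [Image.new('RGB', (64, 64))])))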
diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt
index 05d74bbe72..81f538275c 100644
--- a/requirements/runtime_ascend.txt
+++ b/requirements/runtime_ascend.txt
@@ -1,5 +1,5 @@
accelerate>=0.29.3
-dlinfer-ascend>=0.1.2
+dlinfer-ascend>=0.1.3
einops
fastapi
fire
@@ -16,7 +16,8 @@ safetensors
sentencepiece
shortuuid
tiktoken
-torch<=2.4.0,>=2.0.0
-torchvision<=0.19.0,>=0.15.0
+torch<=2.4.0,>=2.3.1
+torch-npu==2.3.1
+torchvision<=0.19.0,>=0.18.1
transformers
uvicorn
diff --git a/requirements/runtime.txt b/requirements/runtime_cuda.txt
similarity index 82%
rename from requirements/runtime.txt
rename to requirements/runtime_cuda.txt
index ec4957608c..41af6039bd 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime_cuda.txt
@@ -15,8 +15,8 @@ safetensors
sentencepiece
shortuuid
tiktoken
-torch<=2.4.0,>=2.0.0
-torchvision<=0.19.0,>=0.15.0
+torch<=2.5.1,>=2.0.0
+torchvision<=0.20.1,>=0.15.0
transformers
triton==3.0.0; sys_platform == "linux"
uvicorn
diff --git a/requirements/runtime_maca.txt b/requirements/runtime_maca.txt
new file mode 100644
index 0000000000..f65b3827cd
--- /dev/null
+++ b/requirements/runtime_maca.txt
@@ -0,0 +1,22 @@
+accelerate==0.32.1
+einops
+fastapi
+fire
+mmengine-lite
+numpy<2.0.0
+openai
+outlines<0.1.0
+peft<=0.11.1
+pillow
+protobuf
+pydantic>2.0.0
+pynvml
+safetensors
+sentencepiece
+shortuuid
+tiktoken
+torch<=2.4.0,>=2.0.0
+torchvision<=0.19.0,>=0.15.0
+transformers
+triton>=2.1.0; sys_platform == "linux"
+uvicorn
diff --git a/requirements.txt b/requirements_cuda.txt
similarity index 70%
rename from requirements.txt
rename to requirements_cuda.txt
index 91d38808f1..7c1d387dfb 100644
--- a/requirements.txt
+++ b/requirements_cuda.txt
@@ -1,4 +1,4 @@
-r requirements/build.txt
--r requirements/runtime.txt
+-r requirements/runtime_cuda.txt
-r requirements/lite.txt
-r requirements/serve.txt
diff --git a/requirements_maca.txt b/requirements_maca.txt
new file mode 100644
index 0000000000..075b132c8c
--- /dev/null
+++ b/requirements_maca.txt
@@ -0,0 +1,4 @@
+-r requirements/build.txt
+-r requirements/runtime_maca.txt
+-r requirements/lite.txt
+-r requirements/serve.txt
diff --git a/setup.py b/setup.py
index 7a08ac7919..52e180d8a2 100644
--- a/setup.py
+++ b/setup.py
@@ -4,18 +4,14 @@
from setuptools import find_packages, setup
-npu_available = False
-try:
- import torch_npu
-
- npu_available = torch_npu.npu.is_available()
-except ImportError:
- pass
-
pwd = os.path.dirname(__file__)
version_file = 'lmdeploy/version.py'
+def get_target_device():
+ return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda')
+
+
def readme():
with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f:
content = f.read()
@@ -154,16 +150,12 @@ def gen_packages_items():
setup_requires=parse_requirements('requirements/build.txt'),
tests_require=parse_requirements('requirements/test.txt'),
install_requires=parse_requirements(
- 'requirements/runtime_ascend.txt'
- if npu_available else 'requirements/runtime.txt'),
+ f'requirements/runtime_{get_target_device()}.txt'),
extras_require={
'all':
- parse_requirements('requirements_ascend.txt'
- if npu_available else 'requirements.txt'),
- 'lite':
- parse_requirements('requirements/lite.txt'),
- 'serve':
- parse_requirements('requirements/serve.txt')
+ parse_requirements(f'requirements_{get_target_device()}.txt'),
+ 'lite': parse_requirements('requirements/lite.txt'),
+ 'serve': parse_requirements('requirements/serve.txt')
},
has_ext_modules=check_ext_modules,
classifiers=[
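# Sketch of how the new LMDEPLOY_TARGET_DEVICE switch in setup.py resolves the
# per-device requirement files added in this diff (cuda / ascend / maca). The
# mapping simply interpolates the device name into the file paths; run with
# e.g. `LMDEPLOY_TARGET_DEVICE=ascend python this_file.py` to see the effect.
import os

device = os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda')
print(f'install_requires      -> requirements/runtime_{device}.txt')
print(f"extras_require['all'] -> requirements_{device}.txt")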
diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.cu b/src/turbomind/kernels/gemm/moe_utils_v2.cu
index a9e4f7da51..44fec67748 100644
--- a/src/turbomind/kernels/gemm/moe_utils_v2.cu
+++ b/src/turbomind/kernels/gemm/moe_utils_v2.cu
@@ -2,6 +2,7 @@
#include
#include
+#include
#include
#include
#include
diff --git a/src/turbomind/kernels/gemm/test/test_utils.cu b/src/turbomind/kernels/gemm/test/test_utils.cu
index 8f2b4007f6..8ee595ab9b 100644
--- a/src/turbomind/kernels/gemm/test/test_utils.cu
+++ b/src/turbomind/kernels/gemm/test/test_utils.cu
@@ -84,7 +84,7 @@ FastCompare(const T* src, const T* ref, int dims, int bsz, cudaStream_t stream,
thrust::cuda::par.on(stream),
zip_iter,
zip_iter + count,
- [=] __device__(auto tup) {
+ [=] __host__ __device__(thrust::tuple tup) -> Tuple {
float s = thrust::get<0>(tup);
float r = thrust::get<1>(tup);
float abs_diff = fabsf(s - r);
diff --git a/tests/pytorch/engine/test_request.py b/tests/pytorch/engine/test_request.py
index 813a30e8e7..68ef6b9db9 100644
--- a/tests/pytorch/engine/test_request.py
+++ b/tests/pytorch/engine/test_request.py
@@ -3,7 +3,7 @@
import pytest
from lmdeploy.pytorch.engine.request import (RequestManager, RequestType,
- Response, ResponseType)
+ ResponseType)
class TestRequestHander:
@@ -17,36 +17,31 @@ def event_loop(self):
asyncio.set_event_loop(old_loop)
@pytest.fixture
- def thread_safe(self, request):
- yield request.param
+ def manager(self):
+ yield RequestManager()
- @pytest.fixture
- def manager(self, thread_safe):
- yield RequestManager(thread_safe=thread_safe)
-
- @pytest.mark.parametrize('thread_safe', [True, False])
def test_bind(self, manager, event_loop):
def __stop_engine_callback(reqs, **kwargs):
for req in reqs:
- manager.response(
- Response(type=ResponseType.SUCCESS,
- sender_id=req.sender_id,
- req_id=req.req_id,
- data=f'{req.data} success'))
+ resp = req.resp
+ resp.type = ResponseType.SUCCESS
+ resp.data = f'{req.data} success'
+ manager.response(resp)
async def __dummy_loop():
while True:
- manager.step()
- await asyncio.sleep(0.1)
+ try:
+ await manager.step()
+ except Exception:
+ return
- asyncio.set_event_loop(event_loop)
sender = manager.build_sender()
manager.start_loop(__dummy_loop)
# test not bind
- req_id = sender.send_async(RequestType.STOP_ENGINE, None)
- resp = sender.recv(req_id)
+ resp = sender.send_async(RequestType.STOP_ENGINE, None)
+ resp = sender.recv(resp)
assert resp.type == ResponseType.HANDLER_NOT_EXIST
assert manager.is_loop_alive()
@@ -54,6 +49,8 @@ async def __dummy_loop():
# test bind success
sender.send_async(RequestType.STOP_ENGINE, None)
manager.bind_func(RequestType.STOP_ENGINE, __stop_engine_callback)
- req_id = sender.send_async(RequestType.STOP_ENGINE, 'test')
- resp = sender.recv(req_id)
+ resp = sender.send_async(RequestType.STOP_ENGINE, 'test')
+ resp = sender.recv(resp)
assert resp.data == 'test success'
+
+ manager.stop_loop()
diff --git a/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py b/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py
new file mode 100644
index 0000000000..bb165658dd
--- /dev/null
+++ b/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py
@@ -0,0 +1,231 @@
+import pytest
+import torch
+
+
+def _make_A(M, K, group_size, out_dtype, device='cuda'):
+ quant_A = torch.rand(M,
+ K // group_size,
+ group_size,
+ dtype=torch.float32,
+ device=device)
+ # -1 ~ 1
+ quant_A = quant_A * 2 - 1
+ # scaling abs max to fmax
+ finfo = torch.finfo(out_dtype)
+ fmax = finfo.max
+ scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
+ quant_A *= scaling
+ quant_A = quant_A.to(out_dtype).to(torch.float32)
+
+ # create scale and A
+ scale = torch.rand(M, K // group_size, dtype=torch.float32, device=device)
+ scale /= fmax
+ A = quant_A * scale[..., None]
+
+ A = A.reshape(M, K)
+ quant_A = quant_A.reshape(M, K).to(out_dtype)
+ return A, quant_A, scale
+
+
+def _make_B(E, K, N, group_size, out_dtype, device='cuda'):
+ quant_B = torch.rand(E,
+ N // group_size,
+ group_size,
+ K // group_size,
+ group_size,
+ dtype=torch.float32,
+ device=device)
+ quant_B = quant_B * 2 - 1
+
+ # scaling abs max to fmax
+ finfo = torch.finfo(out_dtype)
+ fmax = finfo.max
+ scaling = fmax / quant_B.abs().amax((2, 4), keepdim=True)
+ quant_B *= scaling
+ quant_B = quant_B.to(out_dtype).to(torch.float32)
+
+ scale = torch.rand(E,
+ N // group_size,
+ 1,
+ K // group_size,
+ 1,
+ dtype=torch.float32,
+ device=device)
+ scale /= fmax
+
+ B = quant_B * scale
+
+ B = B.reshape(E, N, K)
+ quant_B = quant_B.reshape(E, N, K).to(out_dtype)
+ scale = scale.reshape(E, N // group_size, K // group_size)
+ return B, quant_B, scale
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9,
+ reason='require device with cc>=9.0')
+class TestFusedMoeBlockedFP8:
+
+ @pytest.fixture
+ def dtype(self):
+ yield torch.float16
+
+ @pytest.fixture
+ def quant_dtype(self):
+ yield torch.float8_e4m3fn
+
+ @pytest.fixture
+ def device(self):
+ yield torch.device('cuda')
+
+ @pytest.fixture
+ def in_size(self):
+ yield 512
+
+ @pytest.fixture
+ def seq_len(self):
+ yield 128
+
+ @pytest.fixture
+ def hidden_size(self):
+ yield 2048
+
+ @pytest.fixture
+ def out_size(self):
+ yield 1024
+
+ @pytest.fixture
+ def num_experts(self):
+ yield 4
+
+ @pytest.fixture
+ def top_k(self):
+ yield 2
+
+ @pytest.fixture
+ def group_size(self):
+ yield 128
+
+ @pytest.fixture
+ def renormalize(self):
+ yield True
+
+ @pytest.fixture
+ def build_hidden_states(self, seq_len, in_size, group_size, quant_dtype,
+ device):
+ yield _make_A(seq_len,
+ in_size,
+ group_size=group_size,
+ out_dtype=quant_dtype,
+ device=device)
+
+ @pytest.fixture
+ def hidden_states(self, build_hidden_states, dtype):
+ yield build_hidden_states[0].to(dtype)
+
+ @pytest.fixture
+ def states_quanted(self, build_hidden_states):
+ yield build_hidden_states[1]
+
+ @pytest.fixture
+ def states_scale(self, build_hidden_states):
+ yield build_hidden_states[2]
+
+ @pytest.fixture
+ def build_w1(self, num_experts, hidden_size, in_size, group_size,
+ quant_dtype, device):
+ yield _make_B(num_experts,
+ in_size,
+ hidden_size,
+ group_size=group_size,
+ out_dtype=quant_dtype,
+ device=device)
+
+ @pytest.fixture
+ def w1(self, build_w1, dtype):
+ yield build_w1[0].to(dtype)
+
+ @pytest.fixture
+ def w1_quant(self, build_w1):
+ yield build_w1[1]
+
+ @pytest.fixture
+ def w1_scale(self, build_w1):
+ yield build_w1[2]
+
+ @pytest.fixture
+ def build_w2(self, num_experts, out_size, hidden_size, group_size,
+ quant_dtype, device):
+ yield _make_B(num_experts,
+ hidden_size // 2,
+ out_size,
+ group_size=group_size,
+ out_dtype=quant_dtype,
+ device=device)
+
+ @pytest.fixture
+ def w2(self, build_w2, dtype):
+ yield build_w2[0].to(dtype)
+
+ @pytest.fixture
+ def w2_quant(self, build_w2):
+ yield build_w2[1]
+
+ @pytest.fixture
+ def w2_scale(self, build_w2):
+ yield build_w2[2]
+
+ @pytest.fixture
+ def router_logits(self, seq_len, num_experts, dtype, device):
+ yield torch.rand(seq_len, num_experts, dtype=dtype, device=device)
+
+ @pytest.fixture
+ def topk_logits(self, router_logits, top_k):
+ routing_weights = torch.softmax(router_logits,
+ dim=-1,
+ dtype=torch.float32)
+ yield torch.topk(routing_weights, top_k, dim=-1)
+
+ @pytest.fixture
+ def topk_weights(self, topk_logits):
+ yield topk_logits[0]
+
+ @pytest.fixture
+ def topk_idx(self, topk_logits):
+ yield topk_logits[1]
+
+ @pytest.fixture
+ def gt(self, hidden_states, w1, w2, topk_weights, topk_idx, top_k,
+ renormalize):
+ from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe
+ output = fused_moe(hidden_states,
+ w1,
+ w2,
+ topk_weights,
+ topk_idx,
+ topk=top_k,
+ renormalize=renormalize)
+ yield output
+
+ @torch.inference_mode()
+ def test_fused_moe(self, states_quanted, states_scale, w1_quant, w1_scale,
+ w2_quant, w2_scale, topk_weights, topk_idx, top_k,
+ renormalize, gt):
+ from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \
+ fused_moe_blocked_fp8
+ output = fused_moe_blocked_fp8(states_quanted,
+ states_scale,
+ w1_quant,
+ w1_scale,
+ w2_quant,
+ w2_scale,
+ topk_weights,
+ topk_idx,
+ topk=top_k,
+ renormalize=renormalize)
+ out_max = output.abs().max()
+ gt_max = gt.abs().max()
+ assert (out_max - gt_max).abs() / out_max < 0.05
+
+ norm_out = output / out_max
+ norm_gt = gt / gt_max
+ torch.testing.assert_close(norm_out, norm_gt, atol=0.05, rtol=1e-3)
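# A minimal sketch of the group-wise (blocked) FP8 quantization the test above
# exercises: each contiguous group of `group_size` values along K gets a single
# float32 scale (group abs-max mapped to the fp8 format's max value), the
# payload is stored as fp8, and dequantization is payload * scale. Generic
# illustration only, not the kernel under test; needs a PyTorch build with
# float8 dtypes.
import torch


def quant_fp8_groupwise(A: torch.Tensor, group_size: int = 128,
                        dtype=torch.float8_e4m3fn):
    M, K = A.shape
    fmax = torch.finfo(dtype).max
    grouped = A.reshape(M, K // group_size, group_size)
    scale = grouped.abs().amax(-1).clamp_min(1e-12) / fmax  # [M, K // group]
    q = (grouped / scale[..., None]).to(dtype)              # fp8 payload
    return q.reshape(M, K), scale


A = torch.randn(4, 256)
q, scale = quant_fp8_groupwise(A)
deq = q.to(torch.float32).reshape(4, -1, 128) * scale[..., None]
print((deq.reshape(4, 256) - A).abs().max())  # round-trip error from fp8 rounding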
diff --git a/tests/pytorch/kernel/test_fused_moe.py b/tests/pytorch/kernel/test_fused_moe.py
index 55e3a75c08..cc309eb6a7 100644
--- a/tests/pytorch/kernel/test_fused_moe.py
+++ b/tests/pytorch/kernel/test_fused_moe.py
@@ -250,3 +250,54 @@ def test_fused_moe(self, hidden_states, w1, w2, topk_weights, topk_idx,
topk=top_k,
renormalize=renormalize)
torch.testing.assert_close(output, gt, atol=1e-3, rtol=1e-3)
+
+
+class TestFusedMoeW8A8(TestFusedMoe):
+
+ @pytest.fixture
+ def quant_states(self, hidden_states):
+ from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \
+ per_token_quant_int8
+ states_i8, states_scale = per_token_quant_int8(hidden_states, 1e-7)
+ yield states_i8, states_scale
+
+ def quant_weight(self, w):
+ from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \
+ per_channel_quant
+ num_experts, num_outs, _ = w.shape
+ w = w.flatten(0, -2)
+ w_i8, w_scale = per_channel_quant(w, torch.int8)
+ w_i8 = w_i8.view(num_experts, num_outs, -1)
+ w_scale = w_scale.view(num_experts, num_outs, -1)
+ return w_i8, w_scale
+
+ @pytest.fixture
+ def quant_w1(self, w1):
+ w_i8, w_scale = self.quant_weight(w1)
+ yield w_i8, w_scale
+
+ @pytest.fixture
+ def quant_w2(self, w2):
+ w_i8, w_scale = self.quant_weight(w2)
+ yield w_i8, w_scale
+
+ @torch.inference_mode()
+ def test_fused_moe(self, quant_states, quant_w1, quant_w2, topk_weights,
+ topk_idx, top_k, renormalize, gt):
+ from lmdeploy.pytorch.kernels.cuda.w8a8_fused_moe import fused_moe_w8a8
+ state_i8, state_scale = quant_states
+ w1_i8, w1_scale = quant_w1
+ w2_i8, w2_scale = quant_w2
+
+ output = fused_moe_w8a8(state_i8,
+ state_scale,
+ w1_i8,
+ w1_scale,
+ w2_i8,
+ w2_scale,
+ topk_weights=topk_weights,
+ topk_ids=topk_idx,
+ topk=top_k,
+ out_dtype=torch.float16,
+ renormalize=renormalize)
+ torch.testing.assert_close(output, gt, atol=5e-3, rtol=1e-3)
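# Generic sketch of the symmetric int8 quantization idea behind the w8a8 MoE
# test above: row-wise abs-max scales (per token for activations, per output
# channel for weights), int8 payloads, and dequantization as payload * scale.
# This mirrors the concept only; it is not the lmdeploy triton kernel API.
import torch


def quant_int8_rowwise(x: torch.Tensor, eps: float = 1e-7):
    scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale


x = torch.randn(8, 64)
q, scale = quant_int8_rowwise(x)
print((q.float() * scale - x).abs().max())  # per-row quantization error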
diff --git a/tests/pytorch/kernel/test_gemm_fp8.py b/tests/pytorch/kernel/test_gemm_fp8.py
new file mode 100644
index 0000000000..242a2db581
--- /dev/null
+++ b/tests/pytorch/kernel/test_gemm_fp8.py
@@ -0,0 +1,193 @@
+import pytest
+import torch
+
+
+def _make_A(M, K, group_size, out_dtype):
+ quant_A = torch.rand(M,
+ K // group_size,
+ group_size,
+ dtype=torch.float32,
+ device='cuda')
+ # -1 ~ 1
+ quant_A = quant_A * 2 - 1
+ # scaling abs max to fmax
+ finfo = torch.finfo(out_dtype)
+ fmax = finfo.max
+ scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
+ quant_A *= scaling
+ quant_A = quant_A.to(out_dtype).to(torch.float32)
+
+ # create scale and A
+ scale = torch.rand(M, K // group_size, dtype=torch.float32, device='cuda')
+ scale /= fmax
+ A = quant_A * scale[..., None]
+
+ A = A.reshape(M, K)
+ quant_A = quant_A.reshape(M, K).to(out_dtype)
+ return A, quant_A, scale
+
+
+def _aligned_size(a, b):
+ return (a + b - 1) // b * b
+
+
+def _make_B(K, N, group_size, out_dtype):
+ K_aligned = _aligned_size(K, group_size)
+ N_aligned = _aligned_size(N, group_size)
+
+ quant_B = torch.rand(K_aligned // group_size,
+ group_size,
+ N_aligned // group_size,
+ group_size,
+ dtype=torch.float32,
+ device='cuda')
+ quant_B = quant_B * 2 - 1
+
+ # scaling abs max to fmax
+ finfo = torch.finfo(out_dtype)
+ fmax = finfo.max
+ scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True)
+ quant_B *= scaling
+ quant_B = quant_B.to(out_dtype).to(torch.float32)
+
+ scale = torch.rand(K_aligned // group_size,
+ 1,
+ N_aligned // group_size,
+ 1,
+ dtype=torch.float32,
+ device='cuda')
+ scale /= fmax
+
+ B = quant_B * scale
+
+ B = B.reshape(K_aligned, N_aligned)[:K, :N]
+ quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]
+ scale = scale.reshape(K_aligned // group_size, N_aligned // group_size)
+ return B, quant_B, scale
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9,
+ reason='require device with cc>=9.0')
+class TestQuantFP8:
+
+ @pytest.fixture
+ def M(self):
+ yield 256
+
+ @pytest.fixture
+ def K(self):
+ yield 512
+
+ @pytest.fixture
+ def group_size(self):
+ yield 128
+
+ @pytest.fixture
+ def out_dtype(self):
+ yield torch.float8_e4m3fn
+
+ @pytest.fixture
+ def build_A(self, M, K, group_size, out_dtype):
+ return _make_A(M, K, group_size, out_dtype)
+
+ @pytest.fixture
+ def A(self, build_A):
+ return build_A[0]
+
+ @pytest.fixture
+ def quant_A(self, build_A):
+ return build_A[1]
+
+ @pytest.fixture
+ def scale(self, build_A):
+ return build_A[2]
+
+ @pytest.fixture
+ def gt(self, quant_A, scale):
+ yield quant_A, scale
+
+ def test_quant_fp8(self, A, group_size, out_dtype, gt):
+ from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+ quant_A_gt, scale_gt = gt
+
+ quant_A, scale = quant_fp8(A, group_size=group_size, dtype=out_dtype)
+ torch.testing.assert_close(scale, scale_gt)
+ diff = (quant_A.to(torch.float16) - quant_A_gt.to(torch.float16)).abs()
+ diff_count = (diff > 1e-5).count_nonzero()
+ assert diff_count / diff.numel() < 1e-4
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9,
+ reason='require device with cc>=9.0')
+class TestGemmFP8:
+
+ @pytest.fixture
+ def M(self):
+ yield 256
+
+ @pytest.fixture
+ def N(self):
+ # test non-aligned
+ yield 1024 + 64
+
+ @pytest.fixture
+ def K(self):
+ yield 512
+
+ @pytest.fixture
+ def group_size(self):
+ yield 128
+
+ @pytest.fixture
+ def quant_dtype(self):
+ yield torch.float8_e4m3fn
+
+ @pytest.fixture
+ def out_dtype(self):
+ yield torch.float16
+
+ @pytest.fixture
+ def build_A(self, M, K, group_size, quant_dtype):
+ return _make_A(M, K, group_size, quant_dtype)
+
+ @pytest.fixture
+ def A(self, build_A, out_dtype):
+ return build_A[0].to(out_dtype)
+
+ @pytest.fixture
+ def quant_A(self, build_A):
+ return build_A[1]
+
+ @pytest.fixture
+ def scale_A(self, build_A):
+ return build_A[2]
+
+ @pytest.fixture
+ def build_B(self, K, N, group_size, quant_dtype):
+ return _make_B(K, N, group_size, quant_dtype)
+
+ @pytest.fixture
+ def B(self, build_B, out_dtype):
+ return build_B[0].to(out_dtype)
+
+ @pytest.fixture
+ def quant_B(self, build_B):
+ return build_B[1]
+
+ @pytest.fixture
+ def scale_B(self, build_B):
+ return build_B[2]
+
+ @pytest.fixture
+ def gt(self, A, B):
+ yield A @ B
+
+ def test_gemm_fp8(self, quant_A, scale_A, quant_B, scale_B, out_dtype, gt):
+ from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import \
+ blocked_gemm_fp8
+ C = blocked_gemm_fp8(quant_A,
+ scale_A,
+ quant_B,
+ scale_B,
+ out_dtype=out_dtype)
+ torch.testing.assert_close(C, gt, atol=0.5, rtol=1e-4)
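# A plain-PyTorch reference for the blocked FP8 GEMM being tested above:
# dequantize A with its per-(row, K-group) scales and B with its per
# (K-block, N-block) scales, then run an ordinary matmul. The scale shapes
# follow the fixtures in the test (scale_A: [M, K//g], scale_B: [K//g, N//g],
# possibly covering a padded N). Reference sketch only, not the triton kernel;
# needs a PyTorch build with float8 dtypes.
import torch


def blocked_gemm_fp8_ref(quant_A, scale_A, quant_B, scale_B,
                         group_size=128, out_dtype=torch.float16):
    A = quant_A.to(torch.float32)
    B = quant_B.to(torch.float32)
    M, K = A.shape
    _, N = B.shape
    # expand block scales back to element granularity, trimming any padding
    a_scale = scale_A.repeat_interleave(group_size, dim=1)[:, :K]
    b_scale = scale_B.repeat_interleave(group_size, dim=0)[:K]
    b_scale = b_scale.repeat_interleave(group_size, dim=1)[:, :N]
    return ((A * a_scale) @ (B * b_scale)).to(out_dtype)


qa = torch.randn(4, 256).to(torch.float8_e4m3fn)
sa = torch.rand(4, 2)
qb = torch.randn(256, 192).to(torch.float8_e4m3fn)
sb = torch.rand(2, 2)
print(blocked_gemm_fp8_ref(qa, sa, qb, sb).shape)  # torch.Size([4, 192])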
diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py
index 3b78053a74..0e53283a87 100644
--- a/tests/test_lmdeploy/test_model.py
+++ b/tests/test_lmdeploy/test_model.py
@@ -220,7 +220,7 @@ def test_llama3_1():
},
}]
actual_prompt = model.messages2prompt(messages, tools=tools)
- expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa
+ expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa
assert actual_prompt == expected_prompt
diff --git a/tests/test_lmdeploy/test_vl_encode.py b/tests/test_lmdeploy/test_vl/test_vl_encode.py
similarity index 100%
rename from tests/test_lmdeploy/test_vl_encode.py
rename to tests/test_lmdeploy/test_vl/test_vl_encode.py
diff --git a/tests/test_lmdeploy/test_vl_template.py b/tests/test_lmdeploy/test_vl_template.py
deleted file mode 100644
index cf8abf9e44..0000000000
--- a/tests/test_lmdeploy/test_vl_template.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import PIL
-
-from lmdeploy.model import MODELS
-from lmdeploy.vl.constants import IMAGE_TOKEN
-from lmdeploy.vl.templates import VLChatTemplateWrapper
-
-
-def test_prompt_to_messages():
- model = MODELS.get('llava-v1')()
- templtae = VLChatTemplateWrapper(model)
- out = templtae.prompt_to_messages('hi')
- assert isinstance(out, list) and isinstance(out[0], dict)
- im = PIL.Image.new(mode='RGB', size=(200, 200))
- out = templtae.prompt_to_messages(('hi', [im]))
- assert isinstance(out, list) and isinstance(out[0], dict)
-
-
-def test_messages2prompt():
- model = MODELS.get('llava-v1')()
- templtae = VLChatTemplateWrapper(model)
- messages = [
- dict(role='user',
- content=[
- dict(type='text', text='q1'),
- dict(type='image_url', image_url=dict(url='xxx'))
- ])
- ]
- prompt = templtae.messages2prompt(messages)
- assert isinstance(prompt, str)
- assert prompt.count(IMAGE_TOKEN) == 1
- expected = (
- 'A chat between a curious human and an artificial intelligence '
- 'assistant. The assistant gives helpful, detailed, and polite '
- "answers to the human's questions. USER: "
- '\nq1 ASSISTANT:')
- assert prompt == expected
-
- messages.append({'role': 'assistant', 'content': 'a1'})
- messages.append({'role': 'user', 'content': 'q2'})
- prompt = templtae.messages2prompt(messages)
- expected = (
- 'A chat between a curious human and an artificial intelligence '
- 'assistant. The assistant gives helpful, detailed, and polite '
- "answers to the human's questions. USER: "
- '\nq1 ASSISTANT: a1USER: q2 ASSISTANT:')
- assert prompt == expected
-
-
-def test_internvl2_conv():
- # https://huggingface.co/OpenGVLab/InternVL2-8B/blob/3bfd3664dea4f3da628785f5125d30f889701253/conversation.py
- from transformers.dynamic_module_utils import get_class_from_dynamic_module
- get_conv_template = get_class_from_dynamic_module(
- 'conversation.get_conv_template', 'OpenGVLab/InternVL2-8B')
- template = get_conv_template('internlm2-chat')
- question1 = 'question1'
- template.append_message(template.roles[0], question1)
- template.append_message(template.roles[1], None)
- model = MODELS.get('internvl2-internlm2')()
- messages = [dict(role='user', content=question1)]
- assert template.get_prompt() == model.messages2prompt(messages)
-
- answer1 = 'answer1'
- template.messages[-1][1] = answer1
- question2 = 'question2'
- template.append_message(template.roles[0], question2)
- template.append_message(template.roles[1], None)
- messages.append(dict(role='assistant', content=answer1))
- messages.append(dict(role='user', content=question2))
- assert template.get_prompt() == model.messages2prompt(messages)
-
-
-def test_llava_conv_chatml_direct():
- model = MODELS.get('llava-chatml')()
- templtae = VLChatTemplateWrapper(model)
- messages = [
- dict(role='user',
- content=[
- dict(type='text', text='q1'),
- dict(type='image_url', image_url=dict(url='xxx'))
- ])
- ]
-
- prompt = templtae.messages2prompt(messages)
- expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>'
- '<|im_start|>user\n\nq1<|im_end|>'
- '<|im_start|>assistant\n')
- assert prompt == expected
-
- messages.append({'role': 'assistant', 'content': 'a1'})
- messages.append({'role': 'user', 'content': 'q2'})
- prompt = templtae.messages2prompt(messages)
- expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>'
- '<|im_start|>user\n\nq1<|im_end|>'
- '<|im_start|>assistant\na1<|im_end|>'
- '<|im_start|>user\nq2<|im_end|>'
- '<|im_start|>assistant\n')
- assert prompt == expected
-
-
-def test_custom_image_token():
- from lmdeploy.vl.templates import DeepSeekVLChatTemplateWrapper
- model = MODELS.get('deepseek-vl')()
- template = DeepSeekVLChatTemplateWrapper(model)
-
- def create_user(query: str):
- item = dict(role='user', content=[dict(type='text', text=query)])
- num = query.count(IMAGE_TOKEN)
- for _ in range(num):
- item['content'].append(
- dict(type='image_url', image_url=dict(url='xxx')))
- return item
-
- def create_assistant(response: str):
- return dict(role='assistant', content=response)
-
- messages = [create_user(f'{IMAGE_TOKEN} q1')]
- prompt = template.messages2prompt(messages)
- expected = ('You are a helpful language and vision assistant. You are able'
- ' to understand the visual content that the user provides, and'
- ' assist the user with a variety of tasks using natural '
- 'language.\n\nUser: q1\n\nAssistant:')
- assert prompt == expected
-
- messages.append(create_assistant('a1'))
- messages.append(create_user(f'q2 {IMAGE_TOKEN}'))
- prompt = template.messages2prompt(messages)
- expected = ('You are a helpful language and vision assistant. You are able'
- ' to understand the visual content that the user provides, and'
- ' assist the user with a variety of tasks using natural '
- 'language.\n\nUser: q1\n\nAssistant: '
- 'a1<|end▁of▁sentence|>User: q2 \n\nAssistant:')
- assert prompt == expected