diff --git a/.github/scripts/eval_chat_config.py b/.github/scripts/eval_chat_config.py index e2463c0f39..74ae7a8968 100644 --- a/.github/scripts/eval_chat_config.py +++ b/.github/scripts/eval_chat_config.py @@ -1,7 +1,7 @@ from copy import deepcopy from mmengine.config import read_base -from opencompass.models import TurboMindModel, TurboMindModelwithChatTemplate +from opencompass.models import TurboMindModelwithChatTemplate with read_base(): # choose a list of datasets @@ -84,6 +84,8 @@ models as hf_mistral_chat_7b # noqa: F401, E501 from opencompass.configs.models.mistral.hf_mixtral_8x7b_instruct_v0_1 import \ models as hf_mixtral_chat_8x7b # noqa: F401, E501 + from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import \ + models as lmdeploy_qwen2_5_7b_instruct # noqa: F401, E501 from opencompass.configs.models.qwen.hf_qwen1_5_7b_chat import \ models as hf_qwen1_5_chat_7b # noqa: F401, E501 from opencompass.configs.models.qwen.hf_qwen1_5_moe_a2_7b_chat import \ @@ -146,10 +148,8 @@ turbomind_internlm2_5_7b_chat_4bits = deepcopy(*lmdeploy_internlm2_5_7b_chat) turbomind_internlm2_5_7b_chat_kvint4 = deepcopy(*lmdeploy_internlm2_5_7b_chat) turbomind_internlm2_5_7b_chat_kvint8 = deepcopy(*lmdeploy_internlm2_5_7b_chat) -turbomind_internlm2_5_7b_chat_batch1 = deepcopy(*lmdeploy_internlm2_5_7b_chat) -turbomind_internlm2_5_7b_chat_batch1_4bits = deepcopy( - *lmdeploy_internlm2_5_7b_chat) pytorch_internlm2_5_7b_chat = deepcopy(*lmdeploy_internlm2_5_7b_chat) +pytorch_internlm2_5_7b_chat_w8a8 = deepcopy(*lmdeploy_internlm2_5_7b_chat) # ===== Configs for internlm/internlm2_5_20b_chat ===== turbomind_internlm2_5_20b_chat = deepcopy(*lmdeploy_internlm2_5_20b_chat) @@ -181,26 +181,6 @@ turbomind_qwen_7b_chat_kvint8 = deepcopy(*lmdeploy_qwen_7b_chat) pytorch_qwen_7b_chat = deepcopy(*lmdeploy_qwen_7b_chat) -# ===== Configs for meta-llama/Llama-2-7b-chat-hf ===== -turbomind_llama2_7b_chat = dict(type=TurboMindModel, - abbr='tb_llama2_chat_7b', - path='meta-llama/Llama-2-7b-chat-hf', - engine_config=dict(session_len=MAX_SESSION_LEN, - max_batch_size=128), - gen_config=dict(top_k=1, - top_p=0.8, - temperature=1.0, - max_new_tokens=MAX_NEW_TOKENS), - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - meta_template=llama2_meta_template, - run_cfg=dict(num_gpus=1), - end_str='[INST]') -turbomind_llama2_7b_chat_4bits = deepcopy(turbomind_llama2_7b_chat) -turbomind_llama2_7b_chat_kvint4 = deepcopy(turbomind_llama2_7b_chat) -turbomind_llama2_7b_chat_kvint8 = deepcopy(turbomind_llama2_7b_chat) - # ===== Configs for meta-llama/Meta-Llama-3-8B-Instruct ===== turbomind_llama3_8b_instruct = deepcopy(*lmdeploy_llama3_8b_instruct) turbomind_llama3_8b_instruct_4bits = deepcopy(*lmdeploy_llama3_8b_instruct) @@ -218,6 +198,7 @@ turbomind_llama3_1_8b_instruct_kvint8 = deepcopy( turbomind_llama3_1_8b_instruct) pytorch_llama3_1_8b_instruct = deepcopy(turbomind_llama3_1_8b_instruct) +pytorch_llama3_1_8b_instruct_w8a8 = deepcopy(turbomind_llama3_1_8b_instruct) # ===== Configs for Qwen/Qwen2-7B-Instruct ===== turbomind_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct) @@ -225,17 +206,36 @@ turbomind_qwen2_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_7b_instruct) turbomind_qwen2_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_7b_instruct) pytorch_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct) +pytorch_qwen2_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_7b_instruct) + +# ===== Configs for Qwen/Qwen25-7B-Instruct ===== +turbomind_qwen2_5_7b_instruct = 
deepcopy(*lmdeploy_qwen2_5_7b_instruct) +turbomind_qwen2_5_7b_instruct_4bits = deepcopy(*lmdeploy_qwen2_5_7b_instruct) +turbomind_qwen2_5_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_5_7b_instruct) +turbomind_qwen2_5_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct) +pytorch_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct) +pytorch_qwen2_5_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct) + +# ===== Configs for meta-llama/Llama-2-7b-chat-hf ===== +turbomind_llama2_7b_chat = deepcopy(*lmdeploy_llama2_7b_chat) +turbomind_llama2_7b_chat_4bits = deepcopy(*lmdeploy_llama2_7b_chat) +turbomind_llama2_7b_chat_kvint4 = deepcopy(*lmdeploy_llama2_7b_chat) +turbomind_llama2_7b_chat_kvint8 = deepcopy(*lmdeploy_llama2_7b_chat) for model in [v for k, v in locals().items() if k.startswith('turbomind_')]: - model['engine_config']['max_batch_size'] = 128 + model['engine_config']['max_batch_size'] = 1 model['gen_config']['do_sample'] = False - model['batch_size'] = 128 + model['batch_size'] = 100 for model in [v for k, v in locals().items() if k.endswith('_4bits')]: model['engine_config']['model_format'] = 'awq' model['abbr'] = model['abbr'] + '_4bits' model['path'] = model['path'] + '-inner-4bits' +for model in [v for k, v in locals().items() if k.endswith('_w8a8')]: + model['abbr'] = model['abbr'] + '_w8a8' + model['path'] = model['path'] + '-inner-w8a8' + for model in [v for k, v in locals().items() if k.endswith('_kvint4')]: model['engine_config']['quant_policy'] = 4 model['abbr'] = model['abbr'] + '_kvint4' @@ -247,24 +247,19 @@ for model in [v for k, v in locals().items() if k.startswith('pytorch_')]: model['abbr'] = model['abbr'].replace('turbomind', 'pytorch') model['backend'] = 'pytorch' - model['engine_config']['max_batch_size'] = 64 - model['gen_config']['do_sample'] = False - model['batch_size'] = 64 - -for model in [v for k, v in locals().items() if '_batch1' in k]: - model['abbr'] = model['abbr'] + '_batch1' model['engine_config']['max_batch_size'] = 1 - model['batch_size'] = 1 + model['gen_config']['do_sample'] = False + model['batch_size'] = 100 basic_pytorch_chat_tp1 = dict(type=TurboMindModelwithChatTemplate, engine_config=dict(session_len=MAX_SESSION_LEN, - max_batch_size=64, + max_batch_size=1, tp=1), gen_config=dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS), max_out_len=MAX_NEW_TOKENS, max_seq_len=MAX_SESSION_LEN, - batch_size=64, + batch_size=100, run_cfg=dict(num_gpus=1)) # ===== Configs for Qwen/Qwen1.5-MoE-A2.7B-Chat ===== @@ -277,6 +272,13 @@ pytorch_gemma_2_9b_it['abbr'] = 'pytorch_gemma_2_9b_it' pytorch_gemma_2_9b_it['path'] = 'google/gemma-2-9b-it' +# ===== Configs for google/gemma2-27b-it ===== +pytorch_gemma_2_27b_it = deepcopy(basic_pytorch_chat_tp1) +pytorch_gemma_2_27b_it['abbr'] = 'pytorch_gemma_2_27b_it' +pytorch_gemma_2_27b_it['path'] = 'google/gemma-2-27b-it' +pytorch_gemma_2_27b_it['run_cfg']['num_gpus'] = 2 +pytorch_gemma_2_27b_it['engine_config']['tp'] = 2 + race_datasets = [race_datasets[1]] # Summarizer diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index bd3876f9ed..e75f728783 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -88,7 +88,7 @@ jobs: - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository - uses: actions/checkout@v3 + uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} @@ -105,10 +105,8 @@ jobs: run: | 
# manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install -e /root/packages/AutoAWQ_kernels - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} @@ -148,9 +146,15 @@ jobs: needs: [benchmark] timeout-minutes: 5 runs-on: [self-hosted, linux-a100] + container: + image: openmmlab/lmdeploy:latest-cu11 + options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" + volumes: + - /nvme/qa_test_models:/nvme/qa_test_models + - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository - uses: actions/checkout@v3 + uses: actions/checkout@v2 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} diff --git a/.github/workflows/daily_ete_test.yml b/.github/workflows/daily_ete_test.yml index dbacfc32f5..d6299e163a 100644 --- a/.github/workflows/daily_ete_test.yml +++ b/.github/workflows/daily_ete_test.yml @@ -130,7 +130,7 @@ jobs: needs: download_pkgs if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}} runs-on: [self-hosted, linux-a100] - timeout-minutes: 120 + timeout-minutes: 150 env: PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /root/modelscope_hub @@ -149,15 +149,14 @@ jobs: - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts - run: cp -r ${{env.TEST_CODE_PATH}}/. . + run: | + cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy - dependency run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install -e /root/packages/AutoAWQ_kernels - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -166,7 +165,6 @@ jobs: pip install ${{env.DEEPSEEK_VL}} --no-deps - name: Check env run: | - pip install transformers pip uninstall -y nvidia-nccl-cu11 python3 -m pip list lmdeploy check_env @@ -244,20 +242,20 @@ jobs: - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts - run: cp -r ${{env.TEST_CODE_PATH}}/. . + run: | + cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy - dependency run: | # manually install flash attn # the install packeage from. 
https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install -e /root/packages/AutoAWQ_kernels - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt + rm -rf ${{env.DEEPSEEK_VL}}/build pip install ${{env.DEEPSEEK_VL}} --no-deps - name: Check env run: | @@ -286,6 +284,8 @@ jobs: mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') + pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true + mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - pipeline continue-on-error: true if: matrix.function == 'pipeline' @@ -294,6 +294,8 @@ jobs: mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') + pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true + mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - restful continue-on-error: true if: matrix.function == 'restful' @@ -302,6 +304,8 @@ jobs: mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') + pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true + mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - restful workspace continue-on-error: true if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'restful' @@ -310,6 +314,8 @@ jobs: mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/tools/restful/test_restful_chat_workspace.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') + pytest autotest/tools/restful/test_restful_chat_workspace.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true + mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - 
local testcase if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'local_case' run: | @@ -344,15 +350,14 @@ jobs: - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts - run: cp -r ${{env.TEST_CODE_PATH}}/. . + run: | + cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy - dependency run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install -e /root/packages/AutoAWQ_kernels - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -436,15 +441,14 @@ jobs: - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts - run: cp -r ${{env.TEST_CODE_PATH}}/. . + run: | + cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy - dependency run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install -e /root/packages/AutoAWQ_kernels - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -497,15 +501,14 @@ jobs: - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts - run: cp -r ${{env.TEST_CODE_PATH}}/. . + run: | + cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy - dependency run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install -e /root/packages/AutoAWQ_kernels - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -560,15 +563,14 @@ jobs: - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts - run: cp -r ${{env.TEST_CODE_PATH}}/. . + run: | + cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy - dependency run: | # manually install flash attn # the install packeage from. 
https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install -e /root/packages/AutoAWQ_kernels - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -600,7 +602,7 @@ jobs: run: | export LMDEPLOY_DIR=$(pwd) - python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true + python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat_w8a8, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct_w8a8, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, turbomind_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct_w8a8, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, pytorch_gemma_2_27b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true - name: Evaluate base models if: matrix.evaluate_type == 'base' run: | @@ -622,11 +624,17 @@ jobs: needs: [test_benchmark] timeout-minutes: 5 runs-on: [self-hosted, linux-a100] + container: + image: openmmlab/lmdeploy:latest-cu11 + options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" + volumes: + - /nvme/qa_test_models:/nvme/qa_test_models + - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro env: BENCHMARK_REPORT_DIR: /nvme/qa_test_models/benchmark-reports/${{ github.run_id }} steps: - name: Clone repository - uses: actions/checkout@v3 + uses: actions/checkout@v2 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} diff --git a/.github/workflows/daily_ete_test_v100.yml b/.github/workflows/daily_ete_test_v100.yml index 8a662b85f5..343cfdea50 100644 --- a/.github/workflows/daily_ete_test_v100.yml +++ b/.github/workflows/daily_ete_test_v100.yml @@ -158,8 +158,7 @@ jobs: run: | # manually install flash attn # the install packeage from. 
https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -167,7 +166,6 @@ jobs: python3 -m pip install -r requirements/test.txt - name: Check env run: | - pip install triton==3.0.0 pip uninstall -y nvidia-nccl-cu11 python3 -m pip list lmdeploy check_env @@ -245,8 +243,7 @@ jobs: run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -254,7 +251,6 @@ jobs: python3 -m pip install -r requirements/test.txt - name: Check env run: | - pip install triton==3.0.0 pip uninstall -y nvidia-nccl-cu11 python3 -m pip list lmdeploy check_env @@ -345,8 +341,7 @@ jobs: run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -354,7 +349,6 @@ jobs: python3 -m pip install -r requirements/test.txt - name: Check env run: | - pip install triton==3.0.0 pip uninstall -y nvidia-nccl-cu11 python3 -m pip list lmdeploy check_env @@ -437,8 +431,7 @@ jobs: run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -446,7 +439,6 @@ jobs: python3 -m pip install -r requirements/test.txt - name: Check env run: | - pip install triton==3.0.0 pip uninstall -y nvidia-nccl-cu11 python3 -m pip list lmdeploy check_env @@ -498,8 +490,7 @@ jobs: run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -507,7 +498,6 @@ jobs: python3 -m pip install -r requirements/test.txt - name: Check env run: | - pip install triton==3.0.0 pip uninstall -y nvidia-nccl-cu11 python3 -m pip list lmdeploy check_env @@ -560,8 +550,7 @@ jobs: run: | # manually install flash attn # the install packeage from. 
https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps - python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | @@ -575,7 +564,6 @@ jobs: echo "OPENCOMPASS_DIR=$(pwd)" >> $GITHUB_ENV - name: Check env run: | - pip install triton==3.0.0 pip uninstall -y nvidia-nccl-cu11 python3 -m pip list lmdeploy check_env @@ -593,13 +581,13 @@ jobs: run: | export LMDEPLOY_DIR=$(pwd) - python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true + python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true - name: Evaluate base models if: matrix.evaluate_type == 'base' run: | export LMDEPLOY_DIR=$(pwd) - python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_5_7b, turbomind_qwen2_5_14b, turbomind_internlm2_5_7b_batch1]" "[*race_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]" /root/evaluation-reports/${{ github.run_id }} base true + python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_5_7b, turbomind_qwen2_5_14b]" "[*race_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]" /root/evaluation-reports/${{ github.run_id }} base true - name: Clear workspace if: always() run: | diff --git a/.github/workflows/evaluate.yml b/.github/workflows/evaluate.yml index b6ab89f595..dbfff04fe2 100644 --- a/.github/workflows/evaluate.yml +++ b/.github/workflows/evaluate.yml @@ -17,7 +17,7 @@ on: required: true description: 'Tested TurboMind models list. eg. 
[internlm_chat_7b,internlm_chat_7b_w8a16]' type: string - default: '[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, turbomind_internlm2_chat_7b_4bits, turbomind_internlm2_chat_7b_kvint4, turbomind_internlm2_chat_7b_kvint8, turbomind_internlm2_5_7b_chat_4bits, turbomind_internlm2_5_7b_chat_kvint4, turbomind_internlm2_5_7b_chat_kvint8, turbomind_internlm2_5_20b_chat_4bits, turbomind_internlm2_5_20b_chat_kvint4, turbomind_internlm2_5_20b_chat_kvint8, turbomind_qwen1_5_7b_chat_4bits, turbomind_qwen1_5_7b_chat_kvint4, turbomind_qwen1_5_7b_chat_kvint8, turbomind_llama2_7b_chat_4bits, turbomind_llama2_7b_chat_kvint4, turbomind_llama2_7b_chat_kvint8, turbomind_llama3_8b_instruct_4bits, turbomind_llama3_8b_instruct_kvint4, turbomind_llama3_8b_instruct_kvint8, turbomind_llama3_1_8b_instruct_4bits, turbomind_llama3_1_8b_instruct_kvint4, turbomind_llama3_1_8b_instruct_kvint8, turbomind_qwen2_7b_instruct_4bits, turbomind_qwen2_7b_instruct_kvint8]' + default: '[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, turbomind_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, pytorch_gemma_2_27b_it, turbomind_internlm2_chat_7b_kvint4, turbomind_internlm2_chat_7b_kvint8, turbomind_internlm2_5_7b_chat_4bits, turbomind_internlm2_5_7b_chat_kvint4, turbomind_internlm2_5_7b_chat_kvint8, pytorch_internlm2_5_7b_chat_w8a8, turbomind_internlm2_5_20b_chat_4bits, turbomind_internlm2_5_20b_chat_kvint4, turbomind_internlm2_5_20b_chat_kvint8, turbomind_qwen1_5_7b_chat_4bits, turbomind_qwen1_5_7b_chat_kvint4, turbomind_qwen1_5_7b_chat_kvint8, turbomind_llama2_7b_chat_4bits, turbomind_llama2_7b_chat_kvint4, turbomind_llama2_7b_chat_kvint8, turbomind_llama3_8b_instruct_4bits, turbomind_llama3_8b_instruct_kvint4, turbomind_llama3_8b_instruct_kvint8, turbomind_llama3_1_8b_instruct_4bits, turbomind_llama3_1_8b_instruct_kvint4, turbomind_llama3_1_8b_instruct_kvint8, pytorch_llama3_1_8b_instruct_w8a8, turbomind_qwen2_7b_instruct_4bits, turbomind_qwen2_7b_instruct_kvint8, turbomind_qwen2_5_7b_instruct_4bits, turbomind_qwen2_5_7b_instruct_kvint8, pytorch_qwen2_5_7b_instruct_w8a8]' chat_datasets: required: true description: 'Tested datasets list. eg. 
[*bbh_datasets,*ceval_datasets,*cmmlu_datasets,*GaokaoBench_datasets,*gpqa_datasets,*gsm8k_datasets,*hellaswag_datasets,*humaneval_datasets,*ifeval_datasets,*math_datasets,*sanitized_mbpp_datasets,*mmlu_datasets,*nq_datasets,*race_datasets,*TheoremQA_datasets,*triviaqa_datasets,*winogrande_datasets,*crowspairs_datasets]' @@ -25,7 +25,7 @@ on: default: '[*mmlu_datasets, *gsm8k_datasets, *ifeval_datasets]' base_models: required: true - description: 'Tested TurboMind models list. eg. [turbomind_internlm2_5_7b, turbomind_qwen2_7b, turbomind_internlm2_5_7b_batch1]' + description: 'Tested TurboMind models list. eg. [turbomind_internlm2_5_7b, turbomind_qwen2_7b]' type: string default: '[turbomind_internlm2_5_7b, turbomind_internlm2_5_7b_4bits, turbomind_internlm2_5_7b_batch1, turbomind_internlm2_5_7b_batch1_4bits, turbomind_qwen2_7b, turbomind_qwen2_5_7b, turbomind_qwen2_5_14b]' baes_datasets: @@ -133,8 +133,10 @@ jobs: run: | # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps + python3 -m pip install /root/packages/flash_attn-*.whl + python3 -m pip install -e /root/packages/AutoAWQ_kernels + python3 -m pip install /root/packages/autoawq-*.whl --no-deps + python3 -m pip install /root/packages/xformers-*.whl --no-deps python3 -m pip install -r /root/models/offline_pkg/requirements.txt - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} diff --git a/.github/workflows/pr_ete_test.yml b/.github/workflows/pr_ete_test.yml index 3a19ebe870..2d1c4b63f5 100644 --- a/.github/workflows/pr_ete_test.yml +++ b/.github/workflows/pr_ete_test.yml @@ -10,7 +10,7 @@ on: - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - - "requirements.txt" + - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" workflow_dispatch: @@ -68,7 +68,7 @@ jobs: export PATH=$PATH:/usr/local/openmpi/bin export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/openmpi/lib python3 -m pip install cmake packaging wheel transformers_stream_generator transformers datasets openai einops timm decord - python3 -m pip install -r requirements.txt -r requirements/test.txt -r requirements/build.txt + python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt -r requirements/build.txt mkdir -p build && cd build &&\ sh ../generate.sh &&\ ninja -j$(nproc) && ninja install &&\ diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index ec6db0682d..3a459050ec 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -10,7 +10,7 @@ on: - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - - "requirements.txt" + - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" push: @@ -24,7 +24,7 @@ on: - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - - "requirements.txt" + - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" tags: @@ -39,6 +39,7 @@ jobs: options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e CUDA_VISIBLE_DEVICES=2,3 --pull never" volumes: - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip + - /nvme/share_data/github-actions/hf_home:/root/.cache/huggingface - /nvme/share_data/github-actions/packages:/root/packages - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: @@ -78,7 +79,7 @@ jobs: python3 -m pip install 
pynvml packaging protobuf transformers_stream_generator # manually install flash attn python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp38-cp38-linux_x86_64.whl - python3 -m pip install -r requirements.txt -r requirements/test.txt + python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt python3 -m pip install . - name: Check env run: | diff --git a/README.md b/README.md index d160338aa6..8ef7b7994f 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,8 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • +
      • Qwen2-MoE (57B-A14B)
  • 
  • +
  • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • @@ -136,6 +138,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • +
  • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
  • diff --git a/README_ja.md b/README_ja.md index fda176229e..77badaac36 100644 --- a/README_ja.md +++ b/README_ja.md @@ -122,6 +122,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • +
   • Qwen2-MoE (57B-A14B)
  • 
  • +
  • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • @@ -133,6 +135,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • +
  • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
  • diff --git a/README_zh-CN.md b/README_zh-CN.md index 6c24b2e500..9f3cd40a64 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -126,6 +126,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • +
   • Qwen2-MoE (57B-A14B)
  • 
  • +
  • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • @@ -137,6 +139,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • +
  • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
  • diff --git a/autotest/config.yaml b/autotest/config.yaml index b4fd4e1712..d92e32a595 100644 --- a/autotest/config.yaml +++ b/autotest/config.yaml @@ -17,15 +17,21 @@ tp_config: Meta-Llama-3-1-70B-Instruct: 4 internlm2_5-7b-chat-1m: 4 Qwen2-7B-Instruct-GPTQ-Int4: 2 - InternVL2-40B: 2 + InternVL2-26B: 2 + InternVL2-40B: 4 + InternVL2_5-26B: 2 + InternVL2_5-38B: 4 MiniCPM-V-2_6: 2 Qwen2.5-72B-Instruct: 4 + gemma-2-27b-it: 2 + DeepSeek-V2-Lite-Chat: 2 turbomind_chat_model: - meta-llama/Llama-3.2-1B-Instruct - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct-AWQ + - meta-llama/Meta-Llama-3-1-70B-Instruct - meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Llama-2-7b-chat-hf - internlm/internlm2_5-7b-chat @@ -35,6 +41,10 @@ turbomind_chat_model: - internlm/internlm-chat-20b - internlm/internlm-xcomposer2-4khd-7b - internlm/internlm-xcomposer2d5-7b + - OpenGVLab/InternVL2_5-1B + - OpenGVLab/InternVL2_5-8B + - OpenGVLab/InternVL2_5-26B + - OpenGVLab/InternVL2_5-38B - OpenGVLab/InternVL2-1B - OpenGVLab/InternVL2-2B - OpenGVLab/InternVL2-8B @@ -42,6 +52,7 @@ turbomind_chat_model: - OpenGVLab/InternVL2-40B - OpenGVLab/InternVL-Chat-V1-5 - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 + - OpenGVLab/InternVL2-Llama3-76B-AWQ - Qwen/Qwen2-7B-Instruct - Qwen/Qwen2-7B-Instruct-AWQ - Qwen/Qwen2-1.5B-Instruct @@ -51,6 +62,7 @@ turbomind_chat_model: - Qwen/Qwen-VL-Chat - Qwen/Qwen2.5-0.5B-Instruct - Qwen/Qwen2.5-7B-Instruct + - Qwen/Qwen2.5-72B-Instruct - Qwen/Qwen2-7B-Instruct-GPTQ-Int4 - Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4 - mistralai/Mistral-7B-Instruct-v0.3 @@ -69,10 +81,12 @@ turbomind_chat_model: - THUDM/glm-4-9b-chat - openbmb/MiniCPM-Llama3-V-2_5 - openbmb/MiniCPM-V-2_6 + - allenai/Molmo-7B-D-0924 pytorch_chat_model: - meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct + - meta-llama/Meta-Llama-3-1-70B-Instruct - meta-llama/Llama-3.2-1B-Instruct - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-11B-Vision-Instruct @@ -81,6 +95,10 @@ pytorch_chat_model: - internlm/internlm2_5-20b-chat - internlm/internlm2-chat-20b - internlm/internlm-chat-20b + - OpenGVLab/InternVL2_5-1B + - OpenGVLab/InternVL2_5-8B + - OpenGVLab/InternVL2_5-26B + - OpenGVLab/InternVL2_5-38B - OpenGVLab/InternVL2-1B - OpenGVLab/InternVL2-2B - OpenGVLab/InternVL2-4B @@ -92,10 +110,11 @@ pytorch_chat_model: - baichuan-inc/Baichuan2-7B-Chat - baichuan-inc/Baichuan2-13B-Chat - 01-ai/Yi-6B-Chat - - liuhaotian/llava-v1.5-13b - - liuhaotian/llava-v1.6-vicuna-7b - Qwen/Qwen2-7B-Instruct - Qwen/Qwen2-1.5B-Instruct + - Qwen/Qwen2.5-0.5B-Instruct + - Qwen/Qwen2.5-7B-Instruct + - Qwen/Qwen2.5-72B-Instruct - Qwen/Qwen1.5-7B-Chat - Qwen/Qwen1.5-MoE-A2.7B-Chat - Qwen/Qwen2-VL-2B-Instruct @@ -104,6 +123,7 @@ pytorch_chat_model: - mistralai/Mixtral-8x7B-Instruct-v0.1 - google/gemma-7b-it - google/gemma-2-9b-it + - google/gemma-2-27b-it - deepseek-ai/deepseek-moe-16b-chat - deepseek-ai/deepseek-coder-1.3b-instruct - deepseek-ai/DeepSeek-V2-Lite-Chat @@ -111,6 +131,7 @@ pytorch_chat_model: - THUDM/cogvlm2-llama3-chinese-chat-19B - THUDM/glm-4v-9b - THUDM/glm-4-9b-chat + - openbmb/MiniCPM-V-2_6 - microsoft/Phi-3-mini-4k-instruct - microsoft/Phi-3-vision-128k-instruct @@ -122,11 +143,16 @@ turbomind_vl_model: - deepseek-ai/deepseek-vl-1.3b-chat - OpenGVLab/InternVL-Chat-V1-5 - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 + - OpenGVLab/InternVL2_5-1B + - OpenGVLab/InternVL2_5-8B + - OpenGVLab/InternVL2_5-26B + - OpenGVLab/InternVL2_5-38B - OpenGVLab/InternVL2-1B - 
OpenGVLab/InternVL2-2B - OpenGVLab/InternVL2-8B - OpenGVLab/InternVL2-26B - OpenGVLab/InternVL2-40B + - OpenGVLab/InternVL2-Llama3-76B-AWQ - internlm/internlm-xcomposer2d5-7b - internlm/internlm-xcomposer2-4khd-7b - openbmb/MiniCPM-Llama3-V-2_5 @@ -136,6 +162,10 @@ pytorch_vl_model: - meta-llama/Llama-3.2-11B-Vision-Instruct - OpenGVLab/InternVL-Chat-V1-5 - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 + - OpenGVLab/InternVL2_5-1B + - OpenGVLab/InternVL2_5-8B + - OpenGVLab/InternVL2_5-26B + - OpenGVLab/InternVL2_5-38B - OpenGVLab/InternVL2-1B - OpenGVLab/InternVL2-2B - OpenGVLab/InternVL2-4B @@ -148,6 +178,7 @@ pytorch_vl_model: - THUDM/cogvlm-chat-hf - THUDM/cogvlm2-llama3-chinese-chat-19B - THUDM/glm-4v-9b + - openbmb/MiniCPM-V-2_6 - microsoft/Phi-3-vision-128k-instruct - microsoft/Phi-3.5-vision-instruct @@ -166,6 +197,8 @@ pytorch_base_model: turbomind_quatization: no_awq: + - meta-llama/Meta-Llama-3-1-70B-Instruct + - Qwen/Qwen2.5-72B-Instruct - Qwen/Qwen1.5-MoE-A2.7B-Chat - Qwen/Qwen2-VL-2B-Instruct - Qwen/Qwen2-VL-7B-Instruct @@ -174,18 +207,28 @@ turbomind_quatization: - deepseek-ai/deepseek-coder-1.3b-instruct - deepseek-ai/DeepSeek-V2-Lite-Chat - codellama/CodeLlama-7b-Instruct-hf + - allenai/Molmo-7B-D-0924 gptq: - internlm/internlm2_5-7b-chat no_kvint4: + - meta-llama/Llama-3.2-1B-Instruct + - OpenGVLab/InternVL2-1B + - OpenGVLab/InternVL2_5-1B - openbmb/MiniCPM-V-2_6 - Qwen/Qwen2-7B-Instruct - Qwen/Qwen2-7B-Instruct-AWQ - Qwen/Qwen2-1.5B-Instruct - Qwen/Qwen2.5-0.5B-Instruct - Qwen/Qwen2.5-7B-Instruct + - Qwen/Qwen2.5-72B-Instruct - Qwen/Qwen2-7B-Instruct-GPTQ-Int4 + - allenai/Molmo-7B-D-0924 no_kvint8: + - deepseek-ai/DeepSeek-V2-Chat + no_converted: - deepseek-ai/DeepSeek-V2-Lite-Chat + - Qwen/Qwen2.5-72B-Instruct + - meta-llama/Meta-Llama-3-1-70B-Instruct pytorch_quatization: awq: @@ -200,23 +243,39 @@ pytorch_quatization: - Qwen/Qwen1.5-7B-Chat - Qwen/Qwen2-7B-Instruct - Qwen/Qwen2-1.5B-Instruct - - microsoft/Phi-3-mini-4k-instruct + - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2-VL-2B-Instruct - Qwen/Qwen2-VL-7B-Instruct + - microsoft/Phi-3-mini-4k-instruct w8a8: - meta-llama/Meta-Llama-3-8B-Instruct + - meta-llama/Llama-3.2-1B-Instruct - meta-llama/Llama-2-7b-chat-hf - internlm/internlm2-chat-20b - internlm/internlm2_5-7b-chat - internlm/internlm2_5-20b-chat - 01-ai/Yi-6B-Chat + - mistralai/Mistral-7B-Instruct-v0.3 + - Qwen/Qwen1.5-7B-Chat + - Qwen/Qwen2-7B-Instruct + - Qwen/Qwen2-1.5B-Instruct + - Qwen/Qwen2.5-7B-Instruct + - microsoft/Phi-3-mini-4k-instruct - internlm/internlm2_5-20b - internlm/internlm2_5-7b + - meta-llama/Meta-Llama-3-1-8B-Instruct no_kvint4: + - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL2-1B - OpenGVLab/InternVL2-4B + - OpenGVLab/InternVL2_5-1B - Qwen/Qwen2-7B-Instruct + - Qwen/Qwen2-7B-Instruct-AWQ - Qwen/Qwen2-1.5B-Instruct + - Qwen/Qwen2.5-0.5B-Instruct + - Qwen/Qwen2.5-7B-Instruct + - Qwen/Qwen2.5-72B-Instruct + - Qwen/Qwen2-7B-Instruct-GPTQ-Int4 - Qwen/Qwen2-VL-2B-Instruct - Qwen/Qwen2-VL-7B-Instruct - deepseek-ai/DeepSeek-V2-Lite-Chat @@ -247,3 +306,4 @@ benchmark_model: - mistralai/Mistral-7B-Instruct-v0.3 - mistralai/Mixtral-8x7B-Instruct-v0.1 - deepseek-ai/DeepSeek-V2-Lite-Chat + - allenai/Molmo-7B-D-0924 diff --git a/autotest/prompt_case.yaml b/autotest/prompt_case.yaml index 9a5ed45724..468f3e49d6 100644 --- a/autotest/prompt_case.yaml +++ b/autotest/prompt_case.yaml @@ -54,6 +54,7 @@ emoji_case: - 好 - '!' 
- u1f44d + - 🌟 traditional_chinese_case: - 介紹澳門景點,使用繁體: - contain: diff --git a/autotest/pytest.ini b/autotest/pytest.ini index 4c963d5bbd..69dc47fa58 100644 --- a/autotest/pytest.ini +++ b/autotest/pytest.ini @@ -5,4 +5,4 @@ python_functions = test_* # test function pytest_runtest_call.tryfirst = True filterwarnings = ignore::UserWarning reruns = 2 -reruns_delay = 10 +reruns_delay = 1 diff --git a/autotest/tools/chat/test_command_chat_hf_pytorch.py b/autotest/tools/chat/test_command_chat_hf_pytorch.py index 1ae3be338b..e6986ec614 100644 --- a/autotest/tools/chat/test_command_chat_hf_pytorch.py +++ b/autotest/tools/chat/test_command_chat_hf_pytorch.py @@ -51,6 +51,27 @@ def test_hf_pytorch_chat_tp2(config, model, cli_case_config, worker_id): assert result, msg +@pytest.mark.order(10) +@pytest.mark.usefixtures('cli_case_config') +@pytest.mark.hf_pytorch_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', get_torch_model_list(tp_num=4)) +def test_hf_pytorch_chat_tp4(config, model, cli_case_config, worker_id): + usercase = 'chat_testcase' + result, chat_log, msg = hf_command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'pytorch', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=4)) + if chat_log is not None: + allure.attach.file(chat_log, + attachment_type=allure.attachment_type.TEXT) + + assert result, msg + + @pytest.mark.order(10) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.hf_turbomind_chat diff --git a/autotest/tools/chat/test_command_chat_hf_turbomind.py b/autotest/tools/chat/test_command_chat_hf_turbomind.py index 2f13898fec..935a21ee86 100644 --- a/autotest/tools/chat/test_command_chat_hf_turbomind.py +++ b/autotest/tools/chat/test_command_chat_hf_turbomind.py @@ -53,6 +53,28 @@ def test_hf_turbomind_chat_tp2(config, model, cli_case_config, worker_id): assert result, msg +@pytest.mark.order(10) +@pytest.mark.usefixtures('cli_case_config') +@pytest.mark.hf_turbomind_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=4)) +def test_hf_turbomind_chat_tp4(config, model, cli_case_config, worker_id): + usercase = 'chat_testcase' + result, chat_log, msg = hf_command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'turbomind', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=4)) + + if chat_log is not None: + allure.attach.file(chat_log, + attachment_type=allure.attachment_type.TEXT) + + assert result, msg + + @pytest.mark.order(10) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.hf_turbomind_chat diff --git a/autotest/tools/chat/test_command_chat_workspace.py b/autotest/tools/chat/test_command_chat_workspace.py index a16d4e32f6..415a1c528c 100644 --- a/autotest/tools/chat/test_command_chat_workspace.py +++ b/autotest/tools/chat/test_command_chat_workspace.py @@ -9,7 +9,8 @@ @pytest.mark.usefixtures('cli_case_config') @pytest.mark.command_chat @pytest.mark.gpu_num_1 -@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=1)) +@pytest.mark.parametrize('model', + get_turbomind_model_list(tp_num=1, is_converted=True)) def test_workspace_chat_tp1(config, cli_case_config, model, worker_id): usercase = 'chat_testcase' # cannot convert with rop-scale params, so the case should be skipped @@ -32,7 +33,8 @@ def test_workspace_chat_tp1(config, cli_case_config, model, worker_id): @pytest.mark.usefixtures('cli_case_config') @pytest.mark.command_chat @pytest.mark.gpu_num_2 -@pytest.mark.parametrize('model', 
get_turbomind_model_list(tp_num=2)) +@pytest.mark.parametrize('model', + get_turbomind_model_list(tp_num=2, is_converted=True)) def test_workspace_chat_tp2(config, cli_case_config, model, worker_id): usercase = 'chat_testcase' result, chat_log, msg = command_line_test( @@ -54,7 +56,8 @@ def test_workspace_chat_tp2(config, cli_case_config, model, worker_id): @pytest.mark.gpu_num_1 @pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=1, - model_type='base_model')) + model_type='base_model', + is_converted=True)) def test_workspace_base_tp1(config, cli_case_config, model, worker_id): usercase = 'base_testcase' result, chat_log, msg = command_line_test( @@ -76,7 +79,8 @@ def test_workspace_base_tp1(config, cli_case_config, model, worker_id): @pytest.mark.gpu_num_2 @pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=2, - model_type='base_model')) + model_type='base_model', + is_converted=True)) def test_workspace_base_tp2(config, cli_case_config, model, worker_id): usercase = 'base_testcase' result, chat_log, msg = command_line_test( diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py index 58674fa173..c0348ec500 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py @@ -56,6 +56,32 @@ def test_pipeline_chat_pytorch_tp2(config, common_case_config, model, worker_id) +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat_pytorch +@pytest.mark.gpu_num_4 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', + get_torch_model_list(tp_num=4, exclude_dup=True)) +def test_pipeline_chat_pytorch_tp4(config, common_case_config, model, + worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, + 'pytorch', worker_id)) + p.start() + p.join() + + # assert script + assert_pipeline_chat_log(config, common_case_config, model, 'pytorch', + worker_id) + + @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat @@ -109,6 +135,34 @@ def test_pipeline_chat_kvint4_tp2(config, common_case_config, model, 'pytorch-kvint', worker_id) +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', + get_torch_model_list(tp_num=4, + quant_policy=4, + exclude_dup=True)) +def test_pipeline_chat_kvint4_tp4(config, common_case_config, model, + worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, + 'pytorch-kvint', worker_id, { + 'quant_policy': 4 + })) + p.start() + p.join() + assert_pipeline_chat_log(config, common_case_config, model, + 'pytorch-kvint', worker_id) + + @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat @@ -162,6 +216,34 @@ def test_pipeline_chat_kvint8_tp2(config, common_case_config, model, 
'pytorch-kvint', worker_id) +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', + get_torch_model_list(tp_num=4, + quant_policy=8, + exclude_dup=True)) +def test_pipeline_chat_kvint8_tp4(config, common_case_config, model, + worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, + 'pytorch-kvint', worker_id, { + 'quant_policy': 8 + })) + p.start() + p.join() + assert_pipeline_chat_log(config, common_case_config, model, + 'pytorch-kvint', worker_id) + + @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat_pytorch diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py index 8403ced94f..8735b8e937 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py @@ -34,6 +34,27 @@ def test_pipeline_chat_tp2(config, model, worker_id): if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_vl_chat_test, + args=(config, model, BACKEND, worker_id)) + p.start() + p.join() + assert_pipeline_vl_chat_log(config, model, worker_id) + + +@pytest.mark.order(6) +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', + get_torch_model_list(tp_num=4, model_type='vl_model')) +def test_pipeline_chat_tp4(config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) spawn_context = get_context('spawn') p = spawn_context.Process(target=run_pipeline_vl_chat_test, args=(config, model, BACKEND, worker_id)) @@ -71,6 +92,29 @@ def test_pipeline_chat_kvint4_tp2(config, model, worker_id): if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_vl_chat_test, + args=(config, model, BACKEND, worker_id, 4)) + p.start() + p.join() + assert_pipeline_vl_chat_log(config, model, worker_id) + + +@pytest.mark.order(6) +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', + get_torch_model_list(tp_num=4, + quant_policy=4, + model_type='vl_model')) +def test_pipeline_chat_kvint4_tp4(config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) spawn_context = get_context('spawn') p = spawn_context.Process(target=run_pipeline_vl_chat_test, args=(config, model, BACKEND, worker_id, 4)) @@ -108,6 +152,29 @@ def test_pipeline_chat_kvint8_tp2(config, model, worker_id): if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = 
get_cuda_id_by_workerid(worker_id, tp_num=2) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_vl_chat_test, + args=(config, model, BACKEND, worker_id, 8)) + p.start() + p.join() + assert_pipeline_vl_chat_log(config, model, worker_id) + + +@pytest.mark.order(6) +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', + get_torch_model_list(tp_num=4, + quant_policy=8, + model_type='vl_model')) +def test_pipeline_chat_kvint8_tp4(config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) spawn_context = get_context('spawn') p = spawn_context.Process(target=run_pipeline_vl_chat_test, args=(config, model, BACKEND, worker_id, 8)) diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py index d1865175cf..58eab0de76 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py @@ -48,6 +48,28 @@ def test_pipeline_chat_tp2(config, common_case_config, model, worker_id): worker_id) +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', get_all_model_list(tp_num=4)) +def test_pipeline_chat_tp4(config, common_case_config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, + 'turbomind', worker_id)) + p.start() + p.join() + assert_pipeline_chat_log(config, common_case_config, model, 'turbomind', + worker_id) + + @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat @@ -95,6 +117,31 @@ def test_pipeline_chat_kvint4_tp2(config, common_case_config, model, 'turbomind-kvint', worker_id) +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', get_all_model_list(tp_num=4, quant_policy=4)) +def test_pipeline_chat_kvint4_tp4(config, common_case_config, model, + worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, + 'turbomind-kvint', worker_id, { + 'quant_policy': 4 + })) + p.start() + p.join() + assert_pipeline_chat_log(config, common_case_config, model, + 'turbomind-kvint', worker_id) + + @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat @@ -142,6 +189,31 @@ def test_pipeline_chat_kvint8_tp2(config, common_case_config, model, 'turbomind-kvint', worker_id) +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', 
get_all_model_list(tp_num=4, quant_policy=8)) +def test_pipeline_chat_kvint8_tp4(config, common_case_config, model, + worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, + 'turbomind-kvint', worker_id, { + 'quant_policy': 8 + })) + p.start() + p.join() + assert_pipeline_chat_log(config, common_case_config, model, + 'turbomind-kvint', worker_id) + + @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py index 8c845fa77a..c62bfc5e8e 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py @@ -34,6 +34,27 @@ def test_pipeline_chat_tp2(config, model, worker_id): if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_vl_chat_test, + args=(config, model, BACKEND, worker_id)) + p.start() + p.join() + assert_pipeline_vl_chat_log(config, model, worker_id) + + +@pytest.mark.order(6) +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', + get_all_model_list(tp_num=4, model_type='vl_model')) +def test_pipeline_chat_tp4(config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) spawn_context = get_context('spawn') p = spawn_context.Process(target=run_pipeline_vl_chat_test, args=(config, model, BACKEND, worker_id)) @@ -71,6 +92,29 @@ def test_pipeline_chat_kvint4_tp2(config, model, worker_id): if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_vl_chat_test, + args=(config, model, BACKEND, worker_id, 4)) + p.start() + p.join() + assert_pipeline_vl_chat_log(config, model, worker_id) + + +@pytest.mark.order(6) +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', + get_all_model_list(tp_num=4, + quant_policy=4, + model_type='vl_model')) +def test_pipeline_chat_kvint4_tp4(config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) spawn_context = get_context('spawn') p = spawn_context.Process(target=run_pipeline_vl_chat_test, args=(config, model, BACKEND, worker_id, 4)) @@ -108,6 +152,29 @@ def test_pipeline_chat_kvint8_tp2(config, model, worker_id): if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) + spawn_context = get_context('spawn') + p = spawn_context.Process(target=run_pipeline_vl_chat_test, + args=(config, model, BACKEND, 
worker_id, 8)) + p.start() + p.join() + assert_pipeline_vl_chat_log(config, model, worker_id) + + +@pytest.mark.order(6) +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('model', + get_all_model_list(tp_num=4, + quant_policy=8, + model_type='vl_model')) +def test_pipeline_chat_kvint8_tp4(config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=4) + os.environ['MASTER_PORT'] = str( + int(worker_id.replace('gw', '')) + 29500) spawn_context = get_context('spawn') p = spawn_context.Process(target=run_pipeline_vl_chat_test, args=(config, model, BACKEND, worker_id, 8)) diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py index fc95e288ca..bc0ea3996a 100644 --- a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py +++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py @@ -60,6 +60,23 @@ def test_restful_chat_tp2(config, common_case_config, worker_id): port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api_pytorch +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getModelList(tp_num=4), + indirect=True) +def test_restful_chat_tp4(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + def getKvintModelList(tp_num, quant_policy): return [{ 'model': item, @@ -104,6 +121,23 @@ def test_restful_chat_kvint4_tp2(config, common_case_config, worker_id): port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=4), + indirect=True) +def test_restful_chat_kvint4_tp4(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + @pytest.mark.order(7) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api @@ -138,6 +172,23 @@ def test_restful_chat_kvint8_tp2(config, common_case_config, worker_id): port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=8), + indirect=True) +def test_restful_chat_kvint8_tp4(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + @pytest.mark.order(7) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py index bf20c45e6e..cc85d35d09 100644 --- a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py +++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py @@ -53,6 +53,19 @@ def test_restful_chat_tp2(config, worker_id): 
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.restful_api_vl +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getModelList(tp_num=4), + indirect=True) +def test_restful_chat_tp4(config, worker_id): + if get_workerid(worker_id) is None: + run_vl_testcase(config) + else: + run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) + + def getKvintModelList(tp_num, quant_policy: int = None): return [{ 'model': item, @@ -89,6 +102,19 @@ def test_restful_chat_kvint4_tp2(config, worker_id): run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.restful_api_vl +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=4), + indirect=True) +def test_restful_chat_kvint4_tp4(config, worker_id): + if get_workerid(worker_id) is None: + run_vl_testcase(config) + else: + run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) + + @pytest.mark.order(7) @pytest.mark.restful_api_vl @pytest.mark.gpu_num_1 @@ -113,3 +139,16 @@ def test_restful_chat_kvint8_tp2(config, worker_id): run_vl_testcase(config) else: run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) + + +@pytest.mark.order(7) +@pytest.mark.restful_api_vl +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=8), + indirect=True) +def test_restful_chat_kvint8_tp4(config, worker_id): + if get_workerid(worker_id) is None: + run_vl_testcase(config) + else: + run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py index 1c9131b32e..435ffc4ae3 100644 --- a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py +++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py @@ -60,6 +60,23 @@ def test_restful_chat_tp2(config, common_case_config, worker_id): port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getModelList(tp_num=4), + indirect=True) +def test_restful_chat_tp4(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + def getKvintModelList(tp_num, quant_policy): return [{ 'model': item, @@ -103,6 +120,23 @@ def test_restful_chat_kvint4_tp2(config, common_case_config, worker_id): port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=4), + indirect=True) +def test_restful_chat_kvint4_tp4(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + @pytest.mark.order(7) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api @@ -137,6 +171,23 @@ def test_restful_chat_kvint8_tp2(config, common_case_config, worker_id): port=DEFAULT_PORT + 
get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=8), + indirect=True) +def test_restful_chat_kvint8_tp4(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + @pytest.mark.order(7) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py index 641f2f760f..bbb8718366 100644 --- a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py +++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py @@ -53,6 +53,19 @@ def test_restful_chat_tp2(config, worker_id): run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.restful_api_vl +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getModelList(tp_num=4), + indirect=True) +def test_restful_chat_tp4(config, worker_id): + if get_workerid(worker_id) is None: + run_vl_testcase(config) + else: + run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) + + def getKvintModelList(tp_num, quant_policy: int = None): return [{ 'model': item, @@ -89,6 +102,19 @@ def test_restful_chat_kvint4_tp2(config, worker_id): run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) +@pytest.mark.order(7) +@pytest.mark.restful_api_vl +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=4), + indirect=True) +def test_restful_chat_kvint4_tp4(config, worker_id): + if get_workerid(worker_id) is None: + run_vl_testcase(config) + else: + run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) + + @pytest.mark.order(7) @pytest.mark.restful_api_vl @pytest.mark.gpu_num_1 @@ -113,3 +139,16 @@ def test_restful_chat_kvint8_tp2(config, worker_id): run_vl_testcase(config) else: run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) + + +@pytest.mark.order(7) +@pytest.mark.restful_api_vl +@pytest.mark.gpu_num_4 +@pytest.mark.parametrize('prepare_environment', + getKvintModelList(tp_num=4, quant_policy=8), + indirect=True) +def test_restful_chat_kvint8_tp4(config, worker_id): + if get_workerid(worker_id) is None: + run_vl_testcase(config) + else: + run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id)) diff --git a/autotest/tools/restful/test_restful_chat_workspace.py b/autotest/tools/restful/test_restful_chat_workspace.py index 798a43d7b0..cf69007cca 100644 --- a/autotest/tools/restful/test_restful_chat_workspace.py +++ b/autotest/tools/restful/test_restful_chat_workspace.py @@ -23,8 +23,7 @@ def getModelList(tp_num): 'model': item, 'cuda_prefix': None, 'tp_num': tp_num - } for item in get_turbomind_model_list(tp_num) - if item not in ('deepseek-ai/deepseek-coder-1.3b-instruct')] + } for item in get_turbomind_model_list(tp_num, is_converted=True)] @pytest.mark.order(7) diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py index 24b4a3f8cd..87d5d73f10 100644 --- a/autotest/utils/config_utils.py +++ b/autotest/utils/config_utils.py @@ -9,17 +9,33 @@ def get_turbomind_model_list(tp_num: int = None, model_type: str = 
'chat_model', - quant_policy: int = None): + quant_policy: int = None, + is_converted: bool = False): config = get_config() if quant_policy is None: - case_list = copy.deepcopy(config.get('turbomind_' + model_type)) + if is_converted: + case_list = [ + x for x in copy.deepcopy(config.get('turbomind_' + model_type)) + if x not in config.get('turbomind_quatization').get( + 'no_converted') + ] + else: + case_list = copy.deepcopy(config.get('turbomind_' + model_type)) else: - case_list = [ - x for x in config.get('turbomind_' + model_type) - if x not in config.get('turbomind_quatization').get( - 'no_kvint' + str(quant_policy)) - ] + if is_converted: + case_list = [ + x for x in config.get('turbomind_' + model_type) + if x not in config.get('turbomind_quatization').get( + 'no_kvint' + str(quant_policy)) and x not in config.get( + 'turbomind_quatization').get('no_converted') + ] + else: + case_list = [ + x for x in config.get('turbomind_' + model_type) + if x not in config.get('turbomind_quatization').get( + 'no_kvint' + str(quant_policy)) + ] quatization_case_config = config.get('turbomind_quatization') for key in config.get('turbomind_' + model_type): @@ -97,7 +113,7 @@ def get_all_model_list(tp_num: int = None, model_type=model_type): if case not in case_list: case_list.append(case) - return [x for x in case_list if 'w8a8' not in x] + return case_list def get_quantization_model_list(type): @@ -202,6 +218,7 @@ def get_benchmark_model_list(tp_num, else: case_list_base = config.get('benchmark_model') quatization_case_config = config.get('turbomind_quatization') + pytorch_quatization_case_config = config.get('pytorch_quatization') case_list = copy.deepcopy(case_list_base) for key in case_list_base: @@ -210,6 +227,12 @@ def get_benchmark_model_list(tp_num, 'no_awq') and not is_quantization_model(key): case_list.append(key + '-inner-4bits') + for key in case_list_base: + if key in config.get('pytorch_chat_model' + ) and key in pytorch_quatization_case_config.get( + 'w8a8') and not is_quantization_model(key): + case_list.append(key + '-inner-w8a8') + model_list = [ item for item in case_list if get_tp_num(config, item) == tp_num ] @@ -228,15 +251,18 @@ def get_benchmark_model_list(tp_num, 'backend': 'pytorch', 'tp_num': tp_num } for item in model_list if '4bits' not in item and ( - item in config.get('pytorch_chat_model') or tp_num == 4)] + item.replace('-inner-w8a8', '') in config.get('pytorch_chat_model') + or tp_num == 4)] for kvint in kvint_list: result += [{ 'model': item, 'backend': 'turbomind', 'quant_policy': kvint, 'tp_num': tp_num - } for item in model_list if item.replace('-inner-4bits', '') in - config.get('turbomind_chat_model')] + } for item in model_list if item.replace( + '-inner-4bits', '') in config.get('turbomind_chat_model') + and item.replace('-inner-4bits', '') not in + quatization_case_config.get('no_kvint' + str(kvint))] return result diff --git a/autotest/utils/pipeline_chat.py b/autotest/utils/pipeline_chat.py index 023e4ac142..8f03e4e406 100644 --- a/autotest/utils/pipeline_chat.py +++ b/autotest/utils/pipeline_chat.py @@ -277,14 +277,14 @@ def assert_pipeline_single_element(output, return result -PIC1 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' # noqa E501 -PIC2 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg' # noqa E501 -PIC_BEIJING = 'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg' # noqa E501 -PIC_CHONGQING = 
'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg' # noqa E501 -PIC_REDPANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg' # noqa E501 -PIC_PANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg' # noqa E501 -DESC = 'What are the similarities and differences between these two images.' # noqa E501 -DESC_ZH = '两张图有什么相同和不同的地方.' # noqa E501 +PIC1 = 'tiger.jpeg' +PIC2 = 'human-pose.jpg' +PIC_BEIJING = 'Beijing_Small.jpeg' +PIC_CHONGQING = 'Chongqing_Small.jpeg' +PIC_REDPANDA = 'redpanda.jpg' +PIC_PANDA = 'panda.jpg' +DESC = 'What are the similarities and differences between these two images.' +DESC_ZH = '两张图有什么相同和不同的地方.' def run_pipeline_vl_chat_test(config, @@ -296,6 +296,7 @@ def run_pipeline_vl_chat_test(config, tp = get_tp_num(config, model_case) model_path = config.get('model_path') hf_path = model_path + '/' + model_case + resource_path = config.get('resource_path') if 'pytorch' in backend: backend_config = PytorchEngineConfig(tp=tp, session_len=8192) @@ -320,7 +321,7 @@ def run_pipeline_vl_chat_test(config, 'pipeline_vl_chat_' + model_case.split('/')[1] + worker_id + '.log') file = open(pipeline_chat_log, 'w') - image = load_image(PIC1) + image = load_image(f'{resource_path}/{PIC1}') if 'deepseek' in model_case: prompt = f'describe this image{IMAGE_TOKEN}' @@ -352,7 +353,7 @@ def run_pipeline_vl_chat_test(config, }, { 'type': 'image_url', 'image_url': { - 'url': PIC1 + 'url': f'{resource_path}/{PIC1}' } }] }] @@ -362,7 +363,7 @@ def run_pipeline_vl_chat_test(config, ', reason: OpenAI format example: tiger not in ' + response.text + '\n') - image_urls = [PIC2, PIC1] + image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}'] images = [load_image(img_url) for img_url in image_urls] response = pipe((prompt, images)) result = 'tiger' in response.text.lower() or 'ski' in response.text.lower( @@ -371,7 +372,7 @@ def run_pipeline_vl_chat_test(config, ', reason: Multi-images example: tiger or ski not in ' + response.text + '\n') - image_urls = [PIC2, PIC1] + image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}'] prompts = [(prompt, load_image(img_url)) for img_url in image_urls] response = pipe(prompts) result = ('ski' in response[0].text.lower() @@ -382,7 +383,7 @@ def run_pipeline_vl_chat_test(config, ', reason: Batch example: ski or tiger not in ' + str(response) + '\n') - image = load_image(PIC2) + image = load_image(f'{resource_path}/{PIC2}') sess = pipe.chat((prompt, image)) result = 'ski' in sess.response.text.lower( ) or '滑雪' in sess.response.text.lower() @@ -397,12 +398,12 @@ def run_pipeline_vl_chat_test(config, sess.response.text + '\n') if 'internvl' in model_case.lower(): - internvl_vl_testcase(config, pipe, file) - internvl_vl_testcase(config, pipe, file, 'cn') + internvl_vl_testcase(pipe, file, resource_path) + internvl_vl_testcase(pipe, file, resource_path, 'cn') if 'minicpm' in model_case.lower(): - MiniCPM_vl_testcase(config, pipe, file) + MiniCPM_vl_testcase(pipe, file, resource_path) if 'qwen' in model_case.lower(): - Qwen_vl_testcase(config, pipe, file) + Qwen_vl_testcase(pipe, file, resource_path) file.close() @@ -410,7 +411,7 @@ def run_pipeline_vl_chat_test(config, torch.cuda.empty_cache() -def internvl_vl_testcase(config, pipe, file, lang='en'): +def internvl_vl_testcase(pipe, file, resource_path, lang='en'): if lang == 'cn': description = DESC_ZH else: @@ -422,9 +423,11 @@ def internvl_vl_testcase(config, 
pipe, file, lang='en'): dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\n{description}'), dict(type='image_url', - image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)), + image_url=dict(max_dynamic_patch=12, + url=f'{resource_path}/{PIC_REDPANDA}')), dict(type='image_url', - image_url=dict(max_dynamic_patch=12, url=PIC_PANDA)) + image_url=dict(max_dynamic_patch=12, + url=f'{resource_path}/{PIC_PANDA}')) ]) ] response = pipe(messages) @@ -452,9 +455,11 @@ def internvl_vl_testcase(config, pipe, file, lang='en'): + # noqa E251,E501 description), dict(type='image_url', - image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)), + image_url=dict(max_dynamic_patch=12, + url=f'{resource_path}/{PIC_REDPANDA}')), dict(type='image_url', - image_url=dict(max_dynamic_patch=12, url=PIC_PANDA)) + image_url=dict(max_dynamic_patch=12, + url=f'{resource_path}/{PIC_PANDA}')) ]) ] response = pipe(messages) @@ -501,8 +506,7 @@ def load_video(video_path, bound=None, num_segments=32): imgs.append(img) return imgs - resource_path = config.get('resource_path') - video_path = resource_path + '/red-panda.mp4' + video_path = f'{resource_path}/red-panda.mp4' imgs = load_video(video_path, num_segments=8) question = '' @@ -546,14 +550,16 @@ def load_video(video_path, bound=None, num_segments=32): response.text + '\n') -def llava_vl_testcase(config, pipe, file): +def llava_vl_testcase(pipe, file, resource_path): # multi-image multi-round conversation, combined images messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), - dict(type='image_url', image_url=dict(url=PIC_BEIJING)), - dict(type='image_url', image_url=dict(url=PIC_CHONGQING)) + dict(type='image_url', + image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')), + dict(type='image_url', + image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}')) ]) ] response = pipe(messages) @@ -575,16 +581,18 @@ def llava_vl_testcase(config, pipe, file): response.text + '\n') -def MiniCPM_vl_testcase(config, pipe, file): +def MiniCPM_vl_testcase(pipe, file, resource_path): # Chat with multiple images messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', - image_url=dict(max_slice_nums=9, url=PIC_REDPANDA)), + image_url=dict(max_slice_nums=9, + url=f'{resource_path}/{PIC_REDPANDA}')), dict(type='image_url', - image_url=dict(max_slice_nums=9, url=PIC_PANDA)) + image_url=dict(max_slice_nums=9, + url=f'{resource_path}/{PIC_PANDA}')) ]) ] response = pipe(messages) @@ -602,27 +610,27 @@ def MiniCPM_vl_testcase(config, pipe, file): response.text + '\n') # In-context few-shot learning - EXAMPLE1 = 'https://github.com/user-attachments/assets/405d9147-95f6-4f78-8879-606a0aed6707' # noqa E251,E501 - EXAMPLE2 = 'https://github.com/user-attachments/assets/9f2c6ed9-2aa5-4189-9c4f-0b9753024ba1' # noqa E251,E501 - EXAMPLE3 = 'https://github.com/user-attachments/assets/f335b507-1957-4c22-84ae-ed69ff79df38' # noqa E251,E501 question = 'production date' messages = [ dict(role='user', content=[ dict(type='text', text=question), - dict(type='image_url', image_url=dict(url=EXAMPLE1)), + dict(type='image_url', + image_url=dict(url=f'{resource_path}/data1.jpeg')), ]), dict(role='assistant', content='2021.08.29'), dict(role='user', content=[ dict(type='text', text=question), - dict(type='image_url', image_url=dict(url=EXAMPLE2)), + dict(type='image_url', + image_url=dict(url=f'{resource_path}/data2.jpeg')), ]), dict(role='assistant', content='1999.05.15'), dict(role='user', 
content=[ dict(type='text', text=question), - dict(type='image_url', image_url=dict(url=EXAMPLE3)), + dict(type='image_url', + image_url=dict(url=f'{resource_path}/data3.jpeg')), ]) ] response = pipe(messages) @@ -651,8 +659,7 @@ def uniform_sample(length, n): print('num frames:', len(frames)) return frames - resource_path = config.get('resource_path') - video_path = resource_path + '/red-panda.mp4' + video_path = f'{resource_path}/red-panda.mp4' frames = encode_video(video_path) question = 'Describe the video' @@ -675,14 +682,16 @@ def uniform_sample(length, n): '\n') -def Qwen_vl_testcase(config, pipe, file): +def Qwen_vl_testcase(pipe, file, resource_path): # multi-image multi-round conversation, combined images messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), - dict(type='image_url', image_url=dict(url=PIC_BEIJING)), - dict(type='image_url', image_url=dict(url=PIC_CHONGQING)) + dict(type='image_url', + image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')), + dict(type='image_url', + image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}')) ]) ] response = pipe(messages) @@ -713,11 +722,11 @@ def Qwen_vl_testcase(config, pipe, file): dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, - url=PIC_BEIJING)), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, - url=PIC_CHONGQING)) + url=f'{resource_path}/{PIC_BEIJING}')), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, + url=f'{resource_path}/{PIC_CHONGQING}')) ]) ] response = pipe(messages) diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index 952de5d9f7..6c33b8bc4b 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -1,11 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import asyncio import csv import os import time from dataclasses import dataclass -from queue import Queue -from threading import Thread from typing import List, Union import numpy as np @@ -24,8 +23,9 @@ os.environ['TM_LOG_LEVEL'] = 'ERROR' -def infer(model, session_id: int, input_ids: List, - gen_config: GenerationConfig, test_round: int, que: Queue): +async def infer(model, session_id: int, input_ids: List, + gen_config: GenerationConfig, test_round: int, + que: asyncio.Queue): if session_id == 1: pbar = tqdm(total=test_round) chatbot = model.create_instance() @@ -47,12 +47,12 @@ def infer(model, session_id: int, input_ids: List, The time elapsing in this iteration `now-prev` is set to the latency of first token of the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0` """ # noqa: E501 - for outputs in chatbot.stream_infer(session_id, - input_ids, - gen_config=gen_config, - sequence_start=True, - sequence_end=True, - stream_output=True): + async for outputs in chatbot.async_stream_infer(session_id, + input_ids, + gen_config=gen_config, + sequence_start=True, + sequence_end=True, + stream_output=True): n_token = outputs.num_token now = time.perf_counter() if n_prev_token != n_token: @@ -61,7 +61,7 @@ def infer(model, session_id: int, input_ids: List, prev = now # for pytorch engine to restart a session if hasattr(chatbot, 'end'): - chatbot.end(session_id) + await chatbot.async_end(session_id) if session_id == 1: pbar.update(1) @@ -69,39 +69,42 @@ def infer(model, session_id: int, input_ids: List, f'Error. 
session_id({session_id}) request {output_seqlen} ' \ f'tokens, but generate {n_token} tokens' stats.append(token_latency_stats[:output_seqlen]) - que.put((session_id, stats)) + await que.put((session_id, stats)) def warmup(model, concurrency: int, input_ids: List[int], warmup_round: int, - gen_config: GenerationConfig): + gen_config: GenerationConfig, event_loop: asyncio.BaseEventLoop): if not warmup_round: return print('start to warmup ...') - def _infer(model, session_id): + async def _infer(model, session_id): chatbot = model.create_instance() for _ in range(warmup_round): - for _ in chatbot.stream_infer(session_id, - input_ids=input_ids, - sequence_start=True, - sequence_end=True, - ignore_eos=True, - gen_config=gen_config): + async for _ in chatbot.async_stream_infer(session_id, + input_ids=input_ids, + sequence_start=True, + sequence_end=True, + ignore_eos=True, + gen_config=gen_config): continue # for pytorch engine to restart a session if hasattr(chatbot, 'end'): - chatbot.end(session_id) + await chatbot.async_end(session_id) _start = time.perf_counter() - procs = [] + + # start threads + tasks = [] for i in range(concurrency): - proc = Thread(target=_infer, args=(model, i + 1), daemon=True) - procs.append(proc) - proc.start() + task = _infer(model, i + 1) + tasks.append(task) + + async def _gather_tasks(tasks): + return await asyncio.gather(*tasks) - for proc in procs: - proc.join() + event_loop.run_until_complete(_gather_tasks(tasks)) _end = time.perf_counter() print(f'end warmup, elapsed time: {round(_end - _start, 2)}s') @@ -125,31 +128,34 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int, from lmdeploy.pytorch.engine import Engine tm_model = Engine(model_path, engine_config) + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + # make up a dummy `input_ids` with the length of `input_seqlen` exactly assert input_seqlen > 0, 'input_seqlen should > 0' input_ids = np.random.randint(low=0, high=101, size=input_seqlen).tolist() - warmup(tm_model, concurrency, input_ids, warmup_round, gen_config) + warmup(tm_model, concurrency, input_ids, warmup_round, gen_config, + event_loop) - que = Queue() - procs = [] + que = asyncio.Queue() _start = time.perf_counter() + tasks = [] for i in range(concurrency): - proc = Thread(target=infer, - args=(tm_model, i + 1, input_ids, gen_config, test_round, - que)) - procs.append(proc) - proc.start() + task = infer(tm_model, i + 1, input_ids, gen_config, test_round, que) + tasks.append(task) + + async def _gather_tasks(tasks): + return await asyncio.gather(*tasks) - for proc in procs: - proc.join() + event_loop.run_until_complete(_gather_tasks(tasks)) _end = time.perf_counter() elapsed_time = _end - _start token_latency_stats = [] while not que.empty(): - _, _stats = que.get() + _, _stats = que.get_nowait() token_latency_stats += _stats # The shape is [concurrency*test_round, output_seqlen] @@ -426,7 +432,6 @@ def main(): block_size=args.cache_block_seq_len, session_len=session_len, tp=args.tp, - thread_safe=True, eager_mode=args.eager_mode, enable_prefix_caching=args.enable_prefix_caching, dtype=args.dtype, diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 4f06fad4f9..291b1be9b8 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -345,12 +345,14 @@ def main(): requests = sample_requests(args.dataset, args.num_prompts, engine.tokenizer) - engine.process_request(requests, - temperature=args.temperature, - top_p=args.top_p, - 
top_k=args.top_k, - concurrency=args.concurrency, - stream_output=True) + engine.process_request( + requests, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + concurrency=args.concurrency + if args.concurrency < args.num_prompts else args.num_prompts, + stream_output=True) if __name__ == '__main__': diff --git a/docker/Dockerfile b/docker/Dockerfile index caa58ee637..24b2b055da 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,9 +10,6 @@ FROM ${CUDA_VERSION} AS final ARG PYTHON_VERSION=3.10 -ARG TORCH_VERSION=2.3.0 -ARG TORCHVISION_VERSION=0.18.0 - RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl openssh-server ssh sudo &&\ curl https://sh.rustup.rs -sSf | sh -s -- -y &&\ add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \ @@ -43,7 +40,6 @@ ENV LD_LIBRARY_PATH=/usr/local/nccl/lib:$LD_LIBRARY_PATH RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install --upgrade pip setuptools==69.5.1 &&\ - python3 -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\ python3 -m pip install cmake packaging wheel ENV NCCL_LAUNCH_MODE=GROUP @@ -54,7 +50,7 @@ COPY . /opt/lmdeploy WORKDIR /opt/lmdeploy RUN --mount=type=cache,target=/root/.cache/pip cd /opt/lmdeploy &&\ - python3 -m pip install -r requirements.txt &&\ + python3 -m pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\ mkdir -p build && cd build &&\ sh ../generate.sh &&\ ninja -j$(nproc) && ninja install &&\ diff --git a/docker/Dockerfile_aarch64_ascend b/docker/Dockerfile_aarch64_ascend index 1c9591197b..5ed842061c 100644 --- a/docker/Dockerfile_aarch64_ascend +++ b/docker/Dockerfile_aarch64_ascend @@ -121,5 +121,4 @@ COPY --from=copy_temp /tmp /opt/lmdeploy WORKDIR /opt/lmdeploy RUN --mount=type=cache,target=/root/.cache/pip \ - sed -i '/triton/d' requirements/runtime.txt && \ - pip3 install -v --no-build-isolation -e . + LMDEPLOY_TARGET_DEVICE=ascend pip3 install -v --no-build-isolation -e . diff --git a/docs/en/advance/pytorch_multithread.md b/docs/en/advance/pytorch_multithread.md new file mode 100644 index 0000000000..446e0fa769 --- /dev/null +++ b/docs/en/advance/pytorch_multithread.md @@ -0,0 +1,78 @@ +# PyTorchEngine Multithread + +We have removed `thread_safe` mode from PytorchEngine since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907). 
We encourage users to achieve high concurrency by using **service API** or **coroutines** whenever possible, for example: + +```python +import asyncio +from lmdeploy import pipeline, PytorchEngineConfig + +event_loop = asyncio.new_event_loop() +asyncio.set_event_loop(event_loop) + +model_path = 'Llama-3.2-1B-Instruct' +pipe = pipeline(model_path, backend_config=PytorchEngineConfig()) + +async def _gather_output(): + tasks = [ + pipe.async_batch_infer('Hakuna Matata'), + pipe.async_batch_infer('giraffes are heartless creatures'), + ] + return await asyncio.gather(*tasks) + +output = asyncio.run(_gather_output()) +print(output[0].text) +print(output[1].text) +``` + +If you do need multithreading, it would be easy to warp it like below: + +```python +import threading +from queue import Queue +import asyncio +from lmdeploy import pipeline, PytorchEngineConfig + +model_path = 'Llama-3.2-1B-Instruct' + + +async def _batch_infer(inque: Queue, outque: Queue, pipe): + while True: + if inque.empty(): + await asyncio.sleep(0) + continue + + input = inque.get_nowait() + output = await pipe.async_batch_infer(input) + outque.put(output) + + +def server(inques, outques): + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + pipe = pipeline(model_path, backend_config=PytorchEngineConfig()) + for inque, outque in zip(inques, outques): + event_loop.create_task(_batch_infer(inque, outque, pipe)) + event_loop.run_forever() + +def client(inque, outque, message): + inque.put(message) + print(outque.get().text) + + +inques = [Queue(), Queue()] +outques = [Queue(), Queue()] + +t_server = threading.Thread(target=server, args=(inques, outques)) +t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata')) +t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures')) + +t_server.start() +t_client0.start() +t_client1.start() + +t_client0.join() +t_client1.join() +``` + +> \[!WARNING\] +> This is NOT recommended, as multithreading introduces additional overhead, leading to unstable inference performance. diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md index 23b86afa61..7da28b5512 100644 --- a/docs/en/get_started/ascend/get_started.md +++ b/docs/en/get_started/ascend/get_started.md @@ -18,7 +18,7 @@ cd lmdeploy ### Environment Preparation -The Docker version is supposed to be no less than `18.03`. And `Ascend Docker Runtime` should be installed by following [the official guide](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/.clusterschedulingig/dlug_installation_012.html). +The Docker version is supposed to be no less than `18.09`. And `Ascend Docker Runtime` should be installed by following [the official guide](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/.clusterschedulingig/dlug_installation_012.html). > \[!CAUTION\] > If error message `libascend_hal.so: cannot open shared object file` shows, that means **Ascend Docker Runtime** is not installed correctly! @@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu ``` Please check [supported_models](../../supported_models/supported_models.md) before use this feature. + +### int8 KV-cache Quantization + +Ascend backend has supported offline int8 KV-cache Quantization on eager mode. 
+ +Please refer this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details. diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md index b3e8bb8abd..8877d510cc 100644 --- a/docs/en/get_started/installation.md +++ b/docs/en/get_started/installation.md @@ -23,7 +23,7 @@ pip install lmdeploy The default prebuilt package is compiled on **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy by: ```shell -export LMDEPLOY_VERSION=0.6.3 +export LMDEPLOY_VERSION=0.6.5 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/docs/en/index.rst b/docs/en/index.rst index 5d49e01c86..54a36c22c8 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -103,6 +103,7 @@ Documentation advance/chat_template.md advance/debug_turbomind.md advance/structed_output.md + advance/pytorch_multithread.md .. toctree:: :maxdepth: 1 diff --git a/docs/en/llm/api_server.md b/docs/en/llm/api_server.md index 285b0e32ff..274ec2ff25 100644 --- a/docs/en/llm/api_server.md +++ b/docs/en/llm/api_server.md @@ -249,6 +249,57 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \ lmdeploy serve gradio api_server_url --server-name ${gradio_ui_ip} --server-port ${gradio_ui_port} ``` +## Launch multiple api servers + +Following are two steps to launch multiple api servers through torchrun. Just create a python script with the following codes. + +1. Launch the proxy server through `lmdeploy serve proxy`. Get the correct proxy server url. +2. Launch the script through `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`.**Note**: Please do not use `0.0.0.0:8000` here, instead, we input the real ip name, `11.25.34.55:8000` for example. + +```python +import os +import socket +from typing import List, Literal + +import fire + + +def get_host_ip(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + return ip + + +def main(model_path: str, + tp: int = 1, + proxy_url: str = 'http://0.0.0.0:8000', + port: int = 23333, + backend: Literal['turbomind', 'pytorch'] = 'turbomind'): + local_rank = int(os.environ.get('LOCAL_RANK', -1)) + world_size = int(os.environ.get('WORLD_SIZE', -1)) + local_ip = get_host_ip() + if isinstance(port, List): + assert len(port) == world_size + port = port[local_rank] + else: + port += local_rank * 10 + if (world_size - local_rank) % tp == 0: + rank_list = ','.join([str(local_rank + i) for i in range(tp)]) + command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\ + f'--server-name {local_ip} --server-port {port} --tp {tp} '\ + f'--proxy-url {proxy_url} --backend {backend}' + print(f'running command: {command}') + os.system(command) + + +if __name__ == '__main__': + fire.Fire(main) +``` + ## FAQ 1. When user got `"finish_reason":"length"`, it means the session is too long to be continued. 
The session length can be diff --git a/docs/en/multi_modal/llava.md b/docs/en/multi_modal/llava.md index 8f052227d5..c374b67121 100644 --- a/docs/en/multi_modal/llava.md +++ b/docs/en/multi_modal/llava.md @@ -6,11 +6,17 @@ LMDeploy supports the following llava series of models, which are detailed in th | :----------------------------------: | :--: | :------------------------: | | llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch | +| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch | +| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch | +| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind | +| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind | The next chapter demonstrates how to deploy an Llava model using LMDeploy, with [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example. +```{note} +PyTorch engine removes the support of original llava models after v0.6.4. Please use their corresponding transformers models instead, which can be found in https://huggingface.co/llava-hf +``` + ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md). diff --git a/docs/en/multi_modal/qwen2_vl.md b/docs/en/multi_modal/qwen2_vl.md index 8b59f84545..fd9f02abaa 100644 --- a/docs/en/multi_modal/qwen2_vl.md +++ b/docs/en/multi_modal/qwen2_vl.md @@ -4,7 +4,7 @@ LMDeploy supports the following Qwen-VL series of models, which are detailed in | Model | Size | Supported Inference Engine | | :----------: | :----: | :------------------------: | -| Qwen-VL-Chat | - | TurboMind, Pytorch | +| Qwen-VL-Chat | - | TurboMind | | Qwen2-VL | 2B, 7B | PyTorch | The next chapter demonstrates how to deploy an Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example. diff --git a/docs/en/quantization/w4a16.md b/docs/en/quantization/w4a16.md index 0aa1e17a5b..c36c3736c6 100644 --- a/docs/en/quantization/w4a16.md +++ b/docs/en/quantization/w4a16.md @@ -128,3 +128,7 @@ We benchmarked the Llama-2-7B-chat and Llama-2-13B-chat models with 4-bit quanti | ---------------- | ------- | ------- | --------- | | Llama-2-7B-chat | 112.9 | 159.4 | 206.4 | | Llama-2-13B-chat | N/A | 90.7 | 115.8 | + +## FAQs + +1. Out of Memory error during quantization due to insufficient GPU memory: This can be addressed by reducing the parameter `--calib-seqlen`, increasing the parameter `--calib-samples`, and set `--batch-size` to 1. diff --git a/docs/en/quantization/w8a8.md b/docs/en/quantization/w8a8.md index 1b1726bd5f..5cdb48f764 100644 --- a/docs/en/quantization/w8a8.md +++ b/docs/en/quantization/w8a8.md @@ -1,55 +1,74 @@ # SmoothQuant -LMDeploy provides functions for quantization and inference of large language models using 8-bit integers. +LMDeploy provides functions for quantization and inference of large language models using 8-bit integers(INT8). For GPUs such as Nvidia H100, lmdeploy also supports 8-bit floating point(FP8). -Before starting inference, ensure that lmdeploy and openai/triton are correctly installed. 
Execute the following commands to install these: +The following NVIDIA GPUs support INT8 and FP8 inference, respectively: + +- INT8 + - V100(sm70): V100 + - Turing(sm75): 20 series, T4 + - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100 + - Ada Lovelace(sm89): 40 series + - Hopper(sm90): H100 +- FP8 + - Ada Lovelace(sm89): 40 series + - Hopper(sm90): H100 + +First of all, run the following command to install lmdeploy: ```shell -pip install lmdeploy -pip install triton>=2.1.0 +pip install lmdeploy[all] ``` -## 8-bit Weight Model Inference +## 8-bit Weight Quantization -For performing 8-bit weight model inference, you can directly download the pre-quantized 8-bit weight models from LMDeploy's [model zoo](https://huggingface.co/lmdeploy). For instance, the 8-bit Internlm-chat-7B model is available for direct download from the model zoo: +Performing 8-bit weight quantization involves three steps: -```shell -git-lfs install -git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon) -``` +1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing. +2. **Replace Modules**: Locate DecoderLayers and replace the modules RSMNorm and nn.Linear with QRSMNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file. +3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model. -Alternatively, you can manually convert original 16-bit weights into 8-bit by referring to the content under the ["8bit Weight Quantization"](#8bit-weight-quantization) section. Save them in the internlm-chat-7b-w8 directory, using the command below: +LMDeploy provides the `lmdeploy lite smooth_quant` command to accomplish all three tasks detailed above. Note that the argument `--quant-dtype` is used to determine whether int8 or fp8 weight quantization is applied. To get more info about the CLI usage, run `lmdeploy lite smooth_quant --help`. -```shell -lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8 -``` +Here are two examples: -Afterwards, use the following command to interact with the model via the terminal: +- int8 -```shell -lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch -``` + ```shell + lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8 + ``` -## Launching gradio service +- fp8 -Coming soon... + ```shell + lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8 + ``` -## Inference Speed +## Inference -Coming soon... +Try the following code to perform batched offline inference with the quantized model: -## 8bit Weight Quantization +```python +from lmdeploy import pipeline, PytorchEngineConfig -Performing 8bit weight quantization involves three steps: +engine_config = PytorchEngineConfig(tp=1) +pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config) +response = pipe(["Hi, pls intro yourself", "Shanghai is"]) +print(response) +``` -1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing. -2. **Replace Modules**: Locate DecoderLayers and replace the modules RSMNorm and nn.Linear with QRSMNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file. -3. 
**Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model. +## Service + +LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below are an example of service startup: + +```shell +lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch +``` -The script `lmdeploy/lite/apis/smooth_quant.py` accomplishes all three tasks detailed above. For example, you can obtain the model weights of the quantized Internlm-chat-7B model by running the following command: +The default port of `api_server` is `23333`. After the server is launched, you can communicate with server on terminal through `api_client`: ```shell -lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8 +lmdeploy serve api_client http://0.0.0.0:23333 ``` -After saving, you can instantiate your quantized model by calling the from_pretrained interface. +You can overview and try out `api_server` APIs online by swagger UI at `http://0.0.0.0:23333`, or you can also read the API specification from [here](../llm/api_server.md). diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 469ece487f..cb9805bb0b 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -4,97 +4,107 @@ The following tables detail the models supported by LMDeploy's TurboMind engine ## TurboMind on CUDA Platform -| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | -| :-------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: | -| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | -| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | -| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | -| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | -| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes | -| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | -| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | -| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | -| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | -| Mistral | 7B | LLM | Yes | Yes | Yes | No | -| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | -| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | -| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | -| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | -| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | -| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No | -| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | -| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | -| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | -| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes | -| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | -| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | -| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | -| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes | -| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | -| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | -| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | +| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | +| :------------------------------: | :--------------: | :--: | :-------: | :-----: | 
:-----: | :---: | +| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | +| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | +| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | +| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes | +| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | +| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | +| Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | +| Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | +| Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No | +| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | +| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | +| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | +| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | +| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | +| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No | +| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | +| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | +| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | +| InternVL2\[2\] | 1 - 2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes | +| InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes | +| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | +| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | +| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | +| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes | +| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | +| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | +| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | "-" means not verified yet. ```{note} -The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. +* [1] The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. 
+* [2] When the head_dim of a model is not 128, such as llama3.2-1B, qwen2-0.5B and internvl2-1B, turbomind doesn't support its kv cache 4/8 bit quantization and inference ``` ## PyTorchEngine on CUDA Platform -| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 | -| :------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: | -| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - | -| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | -| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes | -| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No | -| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No | -| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No | -| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No | -| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes | -| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes | -| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No | -| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes | -| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | -| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | -| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | -| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No | -| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | -| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | -| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | -| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes | -| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | -| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | -| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | -| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | -| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - | -| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | -| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | -| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes | -| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - | -| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - | -| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | -| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - | -| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No | -| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No | -| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - | -| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | -| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | -| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | +| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 | +| :----------------------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: | +| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - | +| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | +| 
InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes | +| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No | +| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No | +| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No | +| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No | +| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes | +| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes | +| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No | +| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes | +| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | +| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | +| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | +| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | +| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | +| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | +| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | +| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | +| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | +| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - | +| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | +| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | +| LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | +| InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - | +| InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - | +| Mono-InternVL\[1\] | 2B | MLLM | Yes | Yes | Yes | - | - | +| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | +| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - | +| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No | +| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes | +| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - | +| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | +| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | +| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | ```{note} -* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. +* [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. +* [2] PyTorch engine removes the support of original llava models after v0.6.4. 
Please use their corresponding transformers models instead, which can be found in https://huggingface.co/llava-hf ``` ## PyTorchEngine on Huawei Ascend Platform diff --git a/docs/zh_cn/advance/pytorch_multithread.md b/docs/zh_cn/advance/pytorch_multithread.md new file mode 100644 index 0000000000..ebd68f503e --- /dev/null +++ b/docs/zh_cn/advance/pytorch_multithread.md @@ -0,0 +1,78 @@ +# PyTorchEngine 多线程推理 + +自 [PR2907](https://github.com/InternLM/lmdeploy/pull/2907) 起,我们废除了 PytorchEngine 的 thread_safe 模式以保证引擎能够更高效的运行。我们鼓励用户尽可能使用**服务接口**或**协程**来实现高并发,比如: + +```python +import asyncio +from lmdeploy import pipeline, PytorchEngineConfig + +event_loop = asyncio.new_event_loop() +asyncio.set_event_loop(event_loop) + +model_path = 'Llama-3.2-1B-Instruct' +pipe = pipeline(model_path, backend_config=PytorchEngineConfig()) + +async def _gather_output(): + tasks = [ + pipe.async_batch_infer('Hakuna Matata'), + pipe.async_batch_infer('giraffes are heartless creatures'), + ] + return await asyncio.gather(*tasks) + +output = asyncio.run(_gather_output()) +print(output[0].text) +print(output[1].text) +``` + +如果你确实有多线程推理的需求,那么可以进行简单的封装,来实现类似的效果。 + +```python +import threading +from queue import Queue +import asyncio +from lmdeploy import pipeline, PytorchEngineConfig + +model_path = 'Llama-3.2-1B-Instruct' + + +async def _batch_infer(inque: Queue, outque: Queue, pipe): + while True: + if inque.empty(): + await asyncio.sleep(0) + continue + + input = inque.get_nowait() + output = await pipe.async_batch_infer(input) + outque.put(output) + + +def server(inques, outques): + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + pipe = pipeline(model_path, backend_config=PytorchEngineConfig()) + for inque, outque in zip(inques, outques): + event_loop.create_task(_batch_infer(inque, outque, pipe)) + event_loop.run_forever() + +def client(inque, outque, message): + inque.put(message) + print(outque.get().text) + + +inques = [Queue(), Queue()] +outques = [Queue(), Queue()] + +t_server = threading.Thread(target=server, args=(inques, outques)) +t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata')) +t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures')) + +t_server.start() +t_client0.start() +t_client1.start() + +t_client0.join() +t_client1.join() +``` + +> \[!WARNING\] +> 我们不鼓励这样实现,多线程会带来额外的开销,使得推理性能不稳定 diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md index b137c458be..e4790253cd 100644 --- a/docs/zh_cn/get_started/ascend/get_started.md +++ b/docs/zh_cn/get_started/ascend/get_started.md @@ -17,7 +17,7 @@ cd lmdeploy ### 环境准备 -Docker 版本应不低于 18.03。并且需按照[官方指南](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/clusterschedulingig/dlug_installation_012.html)安装 Ascend Docker Runtime。 +Docker 版本应不低于 18.09。并且需按照[官方指南](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/clusterschedulingig/dlug_installation_012.html)安装 Ascend Docker Runtime。 > \[!CAUTION\] > 如果在后续容器内出现`libascend_hal.so: cannot open shared object file`错误,说明Ascend Docker Runtime没有被正确安装。 @@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu ``` 支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。 + +### int8 KV-cache 量化 + +昇腾后端现在支持了在eager模式下的离线int8 KV-cache量化。 + 
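The new Ascend int8 KV-cache section above pairs with the relaxed `quant_policy` check later in this diff (`lmdeploy/messages.py`, which now admits KV-cache quantization on both CUDA and Ascend). As a minimal sketch of how this might be exercised from Python — assuming `PytorchEngineConfig` exposes the `device_type`, `eager_mode` and `quant_policy` fields referenced elsewhere in this PR, and using a placeholder model path — one could write:

```python
# Hedged sketch, not taken from the docs in this PR: enable int8 KV cache on the
# Ascend backend in eager mode. Field names follow PytorchEngineConfig as used
# elsewhere in this diff; the model path is a placeholder.
from lmdeploy import pipeline, PytorchEngineConfig

backend_config = PytorchEngineConfig(
    device_type='ascend',  # Huawei Ascend backend
    eager_mode=True,       # the offline int8 KV cache is described for eager mode
    quant_policy=8,        # 8 selects int8 KV-cache quantization
)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
print(pipe(['Hi, please introduce yourself']))
```

The `quant_policy=8` value mirrors the existing CUDA KV int8 convention; the dlinfer article linked below remains the authoritative reference for the offline quantization step itself.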
+详细使用方式请请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。 diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md index 12562c51d5..501f8a13e8 100644 --- a/docs/zh_cn/get_started/installation.md +++ b/docs/zh_cn/get_started/installation.md @@ -23,7 +23,7 @@ pip install lmdeploy 默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3),你可以使用以下命令安装 lmdeploy: ```shell -export LMDEPLOY_VERSION=0.6.3 +export LMDEPLOY_VERSION=0.6.5 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 018a00487f..197e800d58 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能: advance/chat_template.md advance/debug_turbomind.md advance/structed_output.md + advance/pytorch_multithread.md .. toctree:: :maxdepth: 1 diff --git a/docs/zh_cn/llm/api_server.md b/docs/zh_cn/llm/api_server.md index d6c0c42aef..8bb91c619e 100644 --- a/docs/zh_cn/llm/api_server.md +++ b/docs/zh_cn/llm/api_server.md @@ -258,6 +258,89 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \ }' ``` +## 同时启动多个 api_server + +两步直接启动多机多卡服务。先用下面的代码创建一个启动脚本。然后: + +1. 启动代理服务 `lmdeploy serve proxy`。 +2. torchrun 启动脚本 `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **注意**: 多机多卡不要用默认 url `0.0.0.0:8000`,我们需要输入真实ip对应的地址,如:`11.25.34.55:8000`。多机情况下,因为不需要子节点间的通信,所以并不需要用户指定 torchrun 的 `--nnodes` 等参数,只要能保证每个节点执行一次单节点的 torchrun 就行。 + +```python +import os +import socket +from typing import List, Literal + +import fire + + +def get_host_ip(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + return ip + + +def main(model_path: str, + tp: int = 1, + proxy_url: str = 'http://0.0.0.0:8000', + port: int = 23333, + backend: Literal['turbomind', 'pytorch'] = 'turbomind'): + local_rank = int(os.environ.get('LOCAL_RANK', -1)) + world_size = int(os.environ.get('WORLD_SIZE', -1)) + local_ip = get_host_ip() + if isinstance(port, List): + assert len(port) == world_size + port = port[local_rank] + else: + port += local_rank * 10 + if (world_size - local_rank) % tp == 0: + rank_list = ','.join([str(local_rank + i) for i in range(tp)]) + command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\ + f'--server-name {local_ip} --server-port {port} --tp {tp} '\ + f'--proxy-url {proxy_url} --backend {backend}' + print(f'running command: {command}') + os.system(command) + + +if __name__ == '__main__': + fire.Fire(main) +``` + +### 示例 + +为了进一步展示如何在集群环境中使用多机多卡服务。下面提供一个在火山云的用例: + +```shell +#!/bin/bash +# 激活 conda 环境 +source /path/to/your/home/miniconda3/bin/activate /path/to/your/home/miniconda3/envs/your_env +export HOME=/path/to/your/home +# 获取主节点IP地址(假设 MLP_WORKER_0_HOST 是主节点的IP) +MASTER_IP=${MLP_WORKER_0_HOST} +# 检查是否为主节点 +if [ "${MLP_ROLE_INDEX}" -eq 0 ]; then + # 启动 lmdeploy serve proxy 并放入后台 + echo "Starting lmdeploy serve proxy on master node..." + PROXY_PORT=8000 + lmdeploy serve proxy --server-name ${MASTER_IP} --server-port ${PROXY_PORT} & +else + # 这里我们默认调度平台同时启动了所有机器,否则要sleep一会,等待 proxy 启动成功 + echo "Not starting lmdeploy serve proxy on worker node ${MLP_ROLE_INDEX}." 
+fi +# 启动 torchrun 并放入后台 +# 再次强调多机环境下并不需要传--nnodes 或者 --master-addr 等参数,相当于每个机器上执行一次单节点的 torchrun 即可。 +torchrun \ +--nproc_per_node=${MLP_WORKER_GPU} \ +/path/to/script.py \ +InternLM/internlm2-chat-1_8b 8 http://${MASTER_IP}:${PROXY_PORT} +# 打印主机的IP地址 +echo "Host IP addresses:" +hostname -I +``` + ## 接入 WebUI LMDeploy 提供 gradio 和 [OpenAOE](https://github.com/InternLM/OpenAOE) 两种方式,为 api_server 接入 WebUI。 diff --git a/docs/zh_cn/multi_modal/llava.md b/docs/zh_cn/multi_modal/llava.md index c40f37308a..6538d1b861 100644 --- a/docs/zh_cn/multi_modal/llava.md +++ b/docs/zh_cn/multi_modal/llava.md @@ -6,11 +6,17 @@ LMDeploy 支持以下 LLaVA 系列模型,具体如下表所示: | :----------------------------------: | :--: | :----------------: | | llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch | +| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch | +| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch | +| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind | +| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind | 接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型,并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。 +```{note} +自 0.6.4 之后,PyTorch 引擎移除了对 llava 原始模型的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到 +``` + ## 安装 请按照[安装指南](../get_started/installation.md)安装 LMDeploy。 diff --git a/docs/zh_cn/multi_modal/qwen2_vl.md b/docs/zh_cn/multi_modal/qwen2_vl.md index f62d2de74c..7cb7efe93b 100644 --- a/docs/zh_cn/multi_modal/qwen2_vl.md +++ b/docs/zh_cn/multi_modal/qwen2_vl.md @@ -4,7 +4,7 @@ LMDeploy 支持 Qwen-VL 系列模型,具体如下: | Model | Size | Supported Inference Engine | | :----------: | :----: | :------------------------: | -| Qwen-VL-Chat | - | TurboMind, Pytorch | +| Qwen-VL-Chat | - | TurboMind | | Qwen2-VL | 2B, 7B | PyTorch | 本文将以[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)为例,演示使用 LMDeploy 部署 Qwen2-VL 系列模型的方法 diff --git a/docs/zh_cn/quantization/w4a16.md b/docs/zh_cn/quantization/w4a16.md index d69a8a23d2..3cea164dd9 100644 --- a/docs/zh_cn/quantization/w4a16.md +++ b/docs/zh_cn/quantization/w4a16.md @@ -131,3 +131,8 @@ lmdeploy serve api_client http://0.0.0.0:23333 | ---------------- | ------- | ------- | --------- | | Llama-2-7B-chat | 112.9 | 159.4 | 206.4 | | Llama-2-13B-chat | N/A | 90.7 | 115.8 | + +## 快速问答 + +1. 量化时出现 Out of Memory 显存不够:可以通过减小传参 `--calib-seqlen`,增大传参 `--calib-samples`,并使用 `--batch-size` 为 1。 +2. 
量化时,无法链接huggingface并下载数据集。可以尝试使用镜像,`export HF_ENDPOINT=https://hf-mirror.com`。 diff --git a/docs/zh_cn/quantization/w8a8.md b/docs/zh_cn/quantization/w8a8.md index 302dd538fd..3a63c82f8c 100644 --- a/docs/zh_cn/quantization/w8a8.md +++ b/docs/zh_cn/quantization/w8a8.md @@ -1,56 +1,76 @@ # W8A8 LLM 模型部署 -LMDeploy 提供了使用 8 bit 整数对神经网络模型进行量化和推理的功能。 +LMDeploy 提供了使用 8-bit 整数(INT8)和浮点数(FP8)对神经网络模型进行量化和推理的功能。 -在开始推理前,需要确保已经正确安装了 lmdeploy 和 openai/triton。可以通过以下命令进行安装: +可用于 INT8 和 FP8 推理的 NVIDIA GPU 分别为: + +- INT8 + - V100(sm70): V100 + - Turing(sm75): 20 series, T4 + - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100 + - Ada Lovelace(sm89): 40 series + - Hopper(sm90): H100 +- FP8 + - Ada Lovelace(sm89): 40 series + - Hopper(sm90): H100 + +首先,执行如下命令安装lmdeploy: ```shell -pip install lmdeploy -pip install triton>=2.1.0 +pip install lmdeploy[all] ``` -## 8bit 权重模型推理 +## 8-bit 权重量化 -如果你需要进行 8 bit 权重模型推理,可以直接从 LMDeploy 的 [model zoo](https://huggingface.co/lmdeploy) 下载已经量化好的 8bit 权重模型。以8bit 的 Internlm-chat-7B 模型为例,可以从 model zoo 直接下载: +进行 8-bit 权重量化需要经历以下三步: -```shell -git-lfs install -git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon) -``` +1. **权重平滑**:首先对语言模型的权重进行平滑处理,以便更好地进行量化。 +2. **模块替换**:使用 `QRMSNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RMSNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。 +3. **保存量化模型**:完成上述必要的替换后,我们即可保存新的量化模型。 -你也可以参考["8bit 权重量化"](#8bit-权重量化)章节的内容手动将原 16bit 权重量化为 8bit,并保存至 `internlm-chat-7b-w8` 目录下,操作命令如下: +lmdeploy 提供了命令行工具 `lmdeploy lite smooth_quant` 实现了以上三个步骤。并且其中命令行参数 `--quant-dtype` 可以用来控制是进行8-bit整数还是浮点数类型的量化。更多命令行工具使用方式,请执行 `lmdeploy lite smooth_quant --help` 查看。 -```shell -lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8 -``` +以下示例演示了进行 int8 或 fp8 的量化命令。 -然后,执行以下命令,即可在终端与模型对话: +- int8 -```shell -lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch -``` + ```shell + lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8 + ``` -## 启动 gradio 服务 +- fp8 -Coming soon... + ```shell + lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8 + ``` -## 推理速度 +## 模型推理 -Coming soon... +量化后的模型,通过以下几行简单的代码,可以实现离线推理: -## 8bit 权重量化 +```python +from lmdeploy import pipeline, PytorchEngineConfig -进行 8bit 权重量化需要经历以下三步: +engine_config = PytorchEngineConfig(tp=1) +pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config) +response = pipe(["Hi, pls intro yourself", "Shanghai is"]) +print(response) +``` -1. **权重平滑**:首先对语言模型的权重进行平滑处理,以便更好地进行量化。 -2. **模块替换**:使用 `QRSMNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RSMNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。 -3. 
**保存量化模型**:完成上述必要的替换后,我们即可保存新的量化模型。 +关于 pipeline 的详细介绍,请参考[这里](../llm/pipeline.md) -我们在`lmdeploy/lite/api/smooth_quantity.py`脚本中已经实现了以上三个步骤。例如,可以通过以下命令得到量化后的 Internlm-chat-7B 模型的模型权重: +## 推理服务 + +LMDeploy `api_server` 支持把模型一键封装为服务,对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例: ```shell +lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch +``` -lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8 +服务默认端口是23333。在 server 启动后,你可以在终端通过`api_client`与server进行对话: + +```shell +lmdeploy serve api_client http://0.0.0.0:23333 ``` -保存之后,你就可以通过调用from_pretrained接口来实例化你的量化模型。 +还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口,也可直接查阅[文档](../llm/api_server.md),了解各接口的定义和使用方法。 diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index d734523282..83b7a9ca6f 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -4,97 +4,107 @@ ## TurboMind CUDA 平台 -| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | -| :-------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: | -| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | -| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | -| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | -| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | -| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes | -| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | -| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | -| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | -| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | -| Mistral | 7B | LLM | Yes | Yes | Yes | No | -| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | -| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | -| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | -| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | -| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | -| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No | -| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | -| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | -| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | -| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes | -| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | -| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | -| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | -| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes | -| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | -| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | -| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | +| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | +| :------------------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: | +| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | +| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | +| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | +| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes | +| 
InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | +| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | +| Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | +| Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | +| Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No | +| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | +| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | +| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | +| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | +| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | +| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No | +| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | +| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | +| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | +| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes | +| InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes | +| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | +| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | +| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | +| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes | +| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | +| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | +| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | “-” 表示还没有验证。 ```{note} -turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine +* [1] turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine +* [2] 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等 ``` ## PyTorchEngine CUDA 平台 -| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 | -| :------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: | -| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes | -| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - | -| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | -| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes | -| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No | -| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No | -| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No | -| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No | -| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes | -| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes | -| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No | -| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes | -| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | -| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | -| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | -| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No | -| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | -| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | -| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | -| MiniCPM-V-2_6 | 8B | LLM | Yes | No | 
No | Yes | Yes | -| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | -| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | -| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | -| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | -| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - | -| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | -| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | -| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes | -| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - | -| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - | -| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | -| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - | -| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No | -| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No | -| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - | -| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | -| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | -| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | +| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 | +| :----------------------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: | +| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes | +| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - | +| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | +| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes | +| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No | +| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No | +| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No | +| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No | +| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes | +| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes | +| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No | +| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes | +| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | +| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | +| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | +| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | +| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | +| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | +| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | +| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | +| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | +| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - | +| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | +| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | +| LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | +| InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - | +| InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - | +| Mono-InternVL\[1\] | 2B | MLLM | Yes\* | Yes | Yes | - | - | +| ChemVLM 
| 8B-26B | MLLM | Yes | Yes | No | - | - | +| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - | +| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No | +| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes | +| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - | +| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | +| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | +| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | ```{note} -* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. +* [1] 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16 +* [2] 自 0.6.4 之后,PyTorch 引擎移除了对 llava 模型原始格式的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到 ``` ## PyTorchEngine 华为昇腾平台 diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index ce5cbd98ff..760a82b1c9 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -from typing import Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union from transformers import AutoConfig @@ -128,7 +128,8 @@ def check_vl_llm(config: dict) -> bool: return True elif arch == 'MultiModalityCausalLM' and 'language_config' in config: return True - elif arch == 'ChatGLMModel' and 'vision_config' in config: + elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration' + ] and 'vision_config' in config: return True elif arch in supported_archs: return True @@ -193,3 +194,22 @@ def get_model_arch(model_path: str): raise RuntimeError( f'Could not find model architecture from config: {_cfg}') return arch, cfg + + +def search_nested_config(config, key): + """Recursively searches for the value associated with the given key in a + nested configuration of a model.""" + if isinstance(config, Dict): + for k, v in config.items(): + if k == key: + return v + if isinstance(v, (Dict, List)): + result = search_nested_config(v, key) + if result is not None: + return result + elif isinstance(config, List): + for item in config: + result = search_nested_config(item, key) + if result is not None: + return result + return None diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py index d76d6a5f34..499bace485 100644 --- a/lmdeploy/cli/lite.py +++ b/lmdeploy/cli/lite.py @@ -35,6 +35,7 @@ def add_parser_auto_awq(): ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) + ArgumentHelper.dtype(parser) parser.add_argument( '--device', type=str, @@ -71,6 +72,7 @@ def add_parser_auto_gptq(): ArgumentHelper.calib_samples(parser) ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) + ArgumentHelper.dtype(parser) parser.add_argument('--w-bits', type=int, default=4, @@ -99,6 +101,7 @@ def add_parser_calibrate(): ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) + ArgumentHelper.dtype(parser) @staticmethod def add_parser_smooth_quant(): @@ -122,6 +125,8 @@ def add_parser_smooth_quant(): ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) + ArgumentHelper.dtype(parser) + ArgumentHelper.quant_dtype(parser) @staticmethod def auto_awq(args): diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index d4a0e54b1b..939d7a2f7b 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -239,6 +239,7 @@ def add_parser_proxy(): help='the strategy to dispatch requests to nodes') ArgumentHelper.api_keys(parser) 
ArgumentHelper.ssl(parser) + ArgumentHelper.log_level(parser) @staticmethod def gradio(args): diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 6db44930f4..33d5d339cf 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -122,6 +122,16 @@ def dtype(parser, default: str = 'auto'): 'for BF16 models. This option will be ignored if ' 'the model is a quantized model') + @staticmethod + def quant_dtype(parser, default: str = 'int8'): + return parser.add_argument( + '--quant-dtype', + type=str, + default=default, + choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'], + help='data type for the quantized model weights and activations.' + 'Note "fp8" is the short version of "float8_e4m3fn"') + @staticmethod def model_format(parser, default: str = None): return parser.add_argument( @@ -363,7 +373,7 @@ def calib_batchsize(parser): @staticmethod def calib_search_scale(parser): - """Add argument batch_size to parser.""" + """Add argument search_scale to parser.""" return parser.add_argument( '--search-scale', diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py index c41b28fd6e..2c84612839 100644 --- a/lmdeploy/lite/apis/auto_awq.py +++ b/lmdeploy/lite/apis/auto_awq.py @@ -2,6 +2,7 @@ import os import os.path as osp import shutil +from typing import Literal import torch from torch import nn @@ -12,9 +13,7 @@ from lmdeploy.lite.utils import collect_target_modules from lmdeploy.pytorch.check_env import try_import_deeplink -from .calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate - -NORM_TYPE_MAP = NORM_TYPE_MAP # legacy +from .calibrate import LAYER_TYPE_MAP, calibrate def save_vl_model(vl_model, model_path, dst_path): @@ -56,6 +55,7 @@ def auto_awq(model: str, search_scale: bool = False, device: str = 'cuda', revision: str = None, + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', download_dir: str = None): """Perform weight quantization using AWQ algorithm. @@ -77,6 +77,7 @@ def auto_awq(model: str, revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + dtype (str): Data type for loading model weights and calib infer. download_dir (str): Directory to download and load the weights, default to the default cache directory of huggingface. """ @@ -96,6 +97,7 @@ def auto_awq(model: str, w_bits=w_bits, w_group_size=w_group_size, search_scale=search_scale, + dtype=dtype, batch_size=batch_size) layer_type = LAYER_TYPE_MAP[type(model).__name__] diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py index 71f7a5900c..007f831a70 100644 --- a/lmdeploy/lite/apis/calibrate.py +++ b/lmdeploy/lite/apis/calibrate.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from pathlib import Path -from typing import Union +from typing import Literal, Union import torch from torch import nn @@ -11,6 +11,7 @@ from lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2 from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders, load_hf_from_pretrained) +from lmdeploy.vl.model.builder import load_vl_model LAYER_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMDecoderLayer', @@ -204,6 +205,7 @@ def calibrate(model: str, w_bits: int = 4, w_group_size: int = 128, search_scale: bool = False, + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', batch_size: int = 1) -> None: """The main function for loading the model and performing calibration on a given dataset. 
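To make the effect of the new `--dtype` / `--quant-dtype` options more concrete, here is a hedged usage sketch of the lite Python APIs touched in this diff (`auto_awq` above, `smooth_quant` further below). The model id and `work_dir` values are placeholders, and `work_dir` is assumed from the existing API signatures rather than shown in these hunks.

```python
# Hedged usage sketch of the new dtype/quant_dtype options; the model id and
# work_dir values are placeholders, and work_dir is assumed from the existing
# lite API signatures rather than from this diff.
from lmdeploy.lite.apis.auto_awq import auto_awq
from lmdeploy.lite.apis.smooth_quant import smooth_quant

# AWQ 4-bit weight quantization, forcing bf16 weight loading and calibration
auto_awq('internlm/internlm2_5-7b-chat',
         work_dir='./internlm2_5-7b-chat-4bit',
         dtype='bfloat16')

# SmoothQuant W8A8, selecting fp8 ('fp8' is the short form of float8_e4m3fn,
# per the --quant-dtype help text added above)
smooth_quant('internlm/internlm2_5-7b-chat',
             work_dir='./internlm2_5-7b-chat-fp8',
             dtype='auto',
             quant_dtype='fp8')
```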
@@ -224,6 +226,7 @@ def calibrate(model: str, w_group_size (int): Group size for weight quantization statistics. search_scale (bool): Whether search scale ratio. Default to False, which means only smooth quant with 0.5 ratio will be applied. + dtype (str): Data type for loading model weights and calib infer. batch_size (int): The batch size for running the calib samples. Low GPU mem requires small batch_size. Large batch_size reduces the calibration time while costs more VRAM. @@ -239,20 +242,35 @@ def calibrate(model: str, model_type, _ = get_task(model) make_compatible_internvl_config(model) - if model_type == 'llm': - # Load tokenizer and configuration - tokenizer = AutoTokenizer.from_pretrained(model, - trust_remote_code=True) + # Load tokenizer and configuration + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + + if model_type == 'llm': model = load_hf_from_pretrained(model, - torch_dtype=torch.float16, + dtype=dtype, trust_remote_code=True) vl_model = None elif model_type == 'vlm': - from lmdeploy.vl.model.builder import vl_model_with_tokenizer - vl_model, model, tokenizer = vl_model_with_tokenizer(model_path=model) + vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model + model = vl_model + if hasattr(vl_model, 'language_model'): # deepseek-vl, ... + model = vl_model.language_model + if hasattr(vl_model, 'llm'): # MiniCPMV, ... + model = vl_model.llm + model.config.use_cache = False + if dtype == 'float16': + model.half() + elif dtype == 'bfloat16': + assert torch.cuda.is_bf16_supported( + ), 'your device does not support bfloat16 please set --dtype float16' # noqa + model.to(torch.bfloat16) + elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16: + print('Warning: we cast model to float16 to prevent OOM. You' + ' may enforce it bfloat16 by `--dtype bfloat16`') + model.half() + model.eval() - model.config.use_cache = False model_type = type(model).__name__ if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP: raise RuntimeError( diff --git a/lmdeploy/lite/apis/gptq.py b/lmdeploy/lite/apis/gptq.py index 12b88a52cd..eb4418a533 100644 --- a/lmdeploy/lite/apis/gptq.py +++ b/lmdeploy/lite/apis/gptq.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging +from typing import Literal import torch -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from lmdeploy.lite.utils.calib_dataloader import get_calib_loaders @@ -15,6 +16,7 @@ def auto_gptq(model: str, calib_samples: int = 128, calib_seqlen: int = 2048, batch_size: int = 1, + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', revision: str = None): """Perform weight quantization using AWQ algorithm. @@ -29,9 +31,7 @@ def auto_gptq(model: str, calib_seqlen (int): The sequence length for calibration. w_bits (int): Bit number for weight quantization. w_group_size (int): Group size for weight quantization statistics. - search_scale (bool): Whether search scale ratio. Default to False, - which means only smooth quant with 0.5 ratio will be applied. - device (str): Device type of running. + dtype (str): Data type for loading model weights and calib infer. revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. 
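For the GPTQ path, the `dtype` option added above feeds the `torch_dtype` resolution in the next hunk: an explicit `'float16'`/`'bfloat16'` overrides the checkpoint, while `'auto'` keeps the `torch_dtype` declared in the HF config. A hedged call sketch follows; the model id and `work_dir` are placeholders, with `work_dir` assumed from the existing `auto_gptq` signature.

```python
# Hedged sketch of auto_gptq with the new dtype option; model id and work_dir
# are placeholders. 'auto' keeps the torch_dtype declared in the HF config,
# while 'float16' / 'bfloat16' override it, as implemented in the hunk below.
from lmdeploy.lite.apis.gptq import auto_gptq

auto_gptq('internlm/internlm2_5-7b-chat',
          work_dir='./internlm2_5-7b-chat-gptq',
          dtype='auto')
```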
@@ -83,9 +83,18 @@ def auto_gptq(model: str, # load un-quantized model, by default, # the model will always be loaded into CPU memory + hf_config = AutoConfig.from_pretrained(pretrained_model_dir, + revision=revision, + trust_remote_code=True) + torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16) + if dtype == 'float16': + torch_dtype = torch.float16 + elif dtype == 'bfloat16': + torch_dtype = torch.bfloat16 model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, revision=revision, + torch_dtype=torch_dtype, trust_remote_code=True) # quantize model, the examples should be list of dict whose keys diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py index c8df67355e..8d67535bcc 100644 --- a/lmdeploy/lite/apis/smooth_quant.py +++ b/lmdeploy/lite/apis/smooth_quant.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. + +from typing import Literal + import fire import torch from torch import nn @@ -6,7 +9,8 @@ from lmdeploy.lite.apis.calibrate import (LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate) from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP, - awq_layers, smooth_layers) + awq_layers, skipped_module, + smooth_layers) from lmdeploy.lite.utils import collect_target_modules from lmdeploy.pytorch.models import QLinear, QRMSNorm @@ -19,8 +23,20 @@ def smooth_quant(model: str, search_scale: bool = False, batch_size: int = 1, w_bits: int = 8, - device: str = 'cuda'): + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', + device: str = 'cuda', + quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', + 'float8_e5m2'] = 'int8'): + if quant_dtype == 'fp8': + quant_dtype = 'float8_e4m3fn' + + quant_dtype = getattr(torch, quant_dtype, torch.int8) + if quant_dtype.is_floating_point: + q_dtype_info = torch.finfo(quant_dtype) + else: + q_dtype_info = torch.iinfo(quant_dtype) + assert q_dtype_info.bits == w_bits model_path = model vl_model, model, tokenizer, work_dir = calibrate(model, calib_dataset, @@ -31,6 +47,7 @@ def smooth_quant(model: str, w_bits=w_bits, w_group_size=-1, search_scale=search_scale, + dtype=dtype, batch_size=batch_size) # calibrate function exports the calibration statistics @@ -76,16 +93,20 @@ def smooth_quant(model: str, rmsnorms = collect_target_modules(model, norm_type) for name, linear in fcs.items(): + if skipped_module(name): + continue linear.to(device) - q_linear = QLinear.from_float(linear) + q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype) parent_name, _, child_name = name.rpartition('.') parent = model.get_submodule(parent_name) setattr(parent, child_name, q_linear) linear.to('cpu') for name, norm in rmsnorms.items(): + if skipped_module(name): + continue norm.to(device) - q_norm = QRMSNorm.from_float(norm) + q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype) parent_name, _, child_name = name.rpartition('.') parent = model.get_submodule(parent_name) setattr(parent, child_name, q_norm) @@ -95,8 +116,10 @@ def smooth_quant(model: str, from .auto_awq import save_vl_model save_vl_model(vl_model, model_path, work_dir) else: + quant_dtype_s = str(quant_dtype).split('.')[1] model.config.update( - dict(quantization_config=dict(quant_method='smooth_quant'))) + dict(quantization_config=dict(quant_method='smooth_quant', + quant_dtype=f'{quant_dtype_s}'))) model.save_pretrained(work_dir, max_shard_size='2GB', safe_serialization=False) diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py index cf03a75216..3e24a13cc3 100644 --- 
a/lmdeploy/lite/quantization/awq.py +++ b/lmdeploy/lite/quantization/awq.py @@ -43,8 +43,10 @@ 'MixtralDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], - 'post_attention_layernorm': - ['block_sparse_moe.experts.{i}.w1', 'block_sparse_moe.experts.{i}.w3'] + 'post_attention_layernorm': [ + 'block_sparse_moe.gate', 'block_sparse_moe.experts.{i}.w1', + 'block_sparse_moe.experts.{i}.w3' + ] }, 'Qwen2VLDecoderLayer': { 'input_layernorm': @@ -120,7 +122,12 @@ def get_weight_scale(weight, q_group_size=-1): org_shape = weight.shape if q_group_size > 0: weight = weight.view(-1, q_group_size) - scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + abs_weight = weight.abs() + abs_weight_amax = abs_weight.amax(dim=1, keepdim=True) + if abs_weight_amax.min().item() == 0: + print('weight.amax.min is zero, clamping weight.amax to 1e-4') + abs_weight_amax = abs_weight_amax.clamp(min=1e-4) + scale = abs_weight / abs_weight_amax scale = scale.view(org_shape) scale = scale.mean(0) return scale @@ -153,8 +160,13 @@ def smooth_ln_fcs(ln: torch.nn.Module, concat_w = torch.cat([fc.weight for fc in fcs], dim=0) w_scales = get_weight_scale(concat_w, group_size) + w_scales_pow = w_scales.pow(1 - alpha) + if w_scales_pow.min().item() == 0: + print('w_scales.pow(1 - alpha).min is zero, ' + 'clamping w_scales.pow(1 - alpha) to 1e-4') + w_scales_pow = w_scales_pow.clamp(min=1e-4) scales = (act_scales.pow(alpha) / - w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype) + w_scales_pow).clamp(min=1e-4).to(device).to(dtype) scales = scales / (scales[nonzero_positions].max() * scales[nonzero_positions].min()).sqrt() @@ -204,8 +216,13 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module, concat_w = torch.cat([fc.weight for fc in fcs], dim=0) w_scales = get_weight_scale(concat_w, group_size) + w_scales_pow = w_scales.pow(1 - alpha) + if w_scales_pow.min().item() == 0: + print('w_scales.pow(1 - alpha).min is zero, ' + 'clamping w_scales.pow(1 - alpha) to 1e-4') + w_scales_pow = w_scales_pow.clamp(min=1e-4) scales = (act_scales.pow(alpha) / - w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype) + w_scales_pow).clamp(min=1e-4).to(device).to(dtype) scales = scales / (scales.max() * scales.min()).sqrt() # (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale diff --git a/lmdeploy/lite/quantization/calibration.py b/lmdeploy/lite/quantization/calibration.py index e590f1a4eb..1df8f2c740 100644 --- a/lmdeploy/lite/quantization/calibration.py +++ b/lmdeploy/lite/quantization/calibration.py @@ -42,6 +42,9 @@ def __init__(self, tokenizer (PreTrainedTokenizer): Tokenizer of the given model. layer_type (Union[str, type]): Type of the layers to be observed. norm_type (Union[str, type]): Norm type used in the model. + batch_size (int): The batch size for running the calib samples. + Low GPU mem requires small batch_size. Large batch_size + reduces the calibration time while costs more VRAM. device (str, optional): Device where the model should run. Defaults to 'cuda'. 
""" @@ -290,9 +293,14 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}): org_sd = {k: v.cpu() for k, v in block.state_dict().items()} for ratio in range(0, n_grid): - ratio = ratio * 1 / n_grid - scales = (x_max.pow(ratio) / - w_mean.pow(1 - ratio)).clamp(min=1e-4).view(-1) + ratio = ratio / n_grid + w_mean_pow = w_mean.pow(1 - ratio) + if w_mean_pow.min().item() == 0: + print('w_mean.pow(1 - ratio).min is zero, ' + 'clamping w_mean.pow(1 - ratio) to 1e-4') + w_mean_pow = w_mean_pow.clamp(min=1e-4) + scales = (x_max.pow(ratio) / w_mean_pow).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() for fc in linears2scale: fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py index bfd306a743..ac4519371a 100644 --- a/lmdeploy/lite/utils/load.py +++ b/lmdeploy/lite/utils/load.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Literal + import torch from transformers import AutoConfig, AutoModelForCausalLM @@ -7,29 +9,42 @@ def load_hf_from_pretrained(pretrained_model_name_or_path, - dtype=torch.float16, - **kwargs): + dtype: Literal['float16', 'bfloat16', + 'auto'], **kwargs): - if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported(): raise RuntimeError('Your device does not supports bf16(bfloat16), ' 'please change to fp16(float16)') kwargs.pop('config', None) hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, - torch_dtype=dtype, trust_remote_code=True) # HACK hard code for qwen, other configs do not have the `fp16` attribute. - if dtype == torch.float16: - hf_config.fp16 = True - elif dtype == torch.bfloat16: - hf_config.bf16 = True + if hasattr(hf_config, 'fp16') or hasattr(hf_config, 'bf16'): + if dtype == 'bfloat16': + hf_config.bf16 = True + else: + hf_config.fp16 = True + + torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16) + if dtype == 'bfloat16': + torch_dtype = torch.bfloat16 + elif dtype == 'float16': + torch_dtype = torch.float16 + elif dtype == 'auto' and torch_dtype == torch.bfloat16: + print('Warning: we cast model to float16 to prevent OOM. ' + 'You may enforce it bfloat16 by `--dtype bfloat16`') + torch_dtype = torch.float16 with LoadNoInit(): # Load model model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path, config=hf_config, **kwargs) + pretrained_model_name_or_path, + config=hf_config, + torch_dtype=torch_dtype, + **kwargs) model.config.use_cache = False return model diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 90823598ea..2336d10752 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -293,8 +293,11 @@ def __post_init__(self): assert self.device_type in [ 'cuda', 'ascend', 'maca' ], (f'invalid device_type: {self.device_type}') - if self.quant_policy > 0 and self.device_type != 'cuda': - assert False, 'kv cache quantization only works for CUDA.' + if self.quant_policy > 0 and self.device_type not in [ + 'cuda', 'ascend' + ]: + assert False, \ + 'kv cache quantization only works for CUDA and ASCEND.' 
class ResponseType(enum.Enum): diff --git a/lmdeploy/model.py b/lmdeploy/model.py index a4355ea131..f7b80ed102 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -46,6 +46,8 @@ class ChatTemplateConfig: eoh (str | None): end of the user prompt assistant (str | None): begin of the assistant prompt eoa (str | None): end of the assistant prompt + tool (str | None): begin of the tool prompt + eotool (str | None): end of the tool prompt capability: ('completion' | 'infilling' | 'chat' | 'python') = None """ # noqa: E501 @@ -57,6 +59,8 @@ class ChatTemplateConfig: eoh: Optional[str] = None assistant: Optional[str] = None eoa: Optional[str] = None + tool: Optional[str] = None + eotool: Optional[str] = None separator: Optional[str] = None capability: Optional[Literal['completion', 'infilling', 'chat', 'python']] = None @@ -173,6 +177,8 @@ def __init__(self, assistant='', eoa='', separator='', + tool='', + eotool='', **kwargs): super().__init__(**kwargs) self.system = system @@ -183,6 +189,8 @@ def __init__(self, self.separator = separator self.eosys = eosys self.assistant = assistant + self.tool = tool + self.eotool = eotool def get_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the @@ -223,10 +231,12 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs): return self.get_prompt(messages, sequence_start) box_map = dict(user=self.user, assistant=self.assistant, - system=self.system) + system=self.system, + tool=self.tool) eox_map = dict(user=self.eoh, assistant=self.eoa + self.separator, - system=self.eosys) + system=self.eosys, + tool=self.eotool) ret = '' if self.meta_instruction is not None and sequence_start: if len(messages) and messages[0]['role'] != 'system': @@ -819,7 +829,7 @@ class Llama3_1(Llama3): def __init__( self, - tools="""# Tool Instructions + tool="""# Tool Instructions - Always execute python code in messages that you share. 
- When looking for real time information use relevant functions if available else fallback to brave_search @@ -828,7 +838,7 @@ def __init__( You have access to the following functions: """, # noqa - eotools=""" + eotool=""" If a you choose to call a function ONLY reply in the following format: <{start_tag}={function_name}>{parameters}{end_tag} @@ -847,7 +857,7 @@ def __init__( - Only call one function at a time - Put the entire function call reply on one line" - Always add your sources when using search results to answer the user query\n\n""", # noqa - knowledge='Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n', + knowledge='Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n', meta_instruction='You are a helpful assistant.', ipython='<|start_header_id|>ipython<|end_header_id|>\n\n', eoi='<|eot_id|>', @@ -858,8 +868,8 @@ def __init__( **kwargs) self.ipython = ipython self.eoi = eoi - self.tools = tools - self.eotools = eotools + self.tool = tool + self.eotool = eotool self.knowledge = knowledge def messages2prompt(self, @@ -899,7 +909,7 @@ def messages2prompt(self, if tools is None: ret += f'{self.system}{self.knowledge}{self.meta_instruction}{self.eosys}' else: - ret += f'{self.system}{self.knowledge}{self.tools}{tool_prompt}{self.eotools}{self.meta_instruction}{self.eosys}' + ret += f'{self.system}{self.knowledge}{self.tool}{tool_prompt}{self.eotool}{self.meta_instruction}{self.eosys}' for message in messages: role = message['role'] content = get_text(message['content']) @@ -907,7 +917,7 @@ def messages2prompt(self, or '' in content): ret += f'{box_map[role]}{content}<|eom_id|>' elif role == 'system' and tools is not None: - ret += f'{box_map[role]}{self.tools}{tool_prompt}{self.eotools}{content}{eox_map[role]}' + ret += f'{box_map[role]}{self.tool}{tool_prompt}{self.eotool}{content}{eox_map[role]}' else: ret += f'{box_map[role]}{content}{eox_map[role]}' if sequence_start and not isinstance(messages, str): @@ -1921,5 +1931,5 @@ def best_match_model(query: str) -> Optional[str]: for name, model in MODELS.module_dict.items(): if model.match(query): return model.match(query) - logger.warn(f'Did not find a chat template matching {query}.') + logger.warning(f'Did not find a chat template matching {query}.') return 'base' diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 92a0befbf4..f0e60d86ac 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -34,6 +34,7 @@ def __init__( alibi: bool = None, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ) -> None: if scale is None: @@ -53,6 +54,7 @@ def __init__( self.alibi = alibi self.sliding_window = sliding_window self.logit_softcapping = logit_softcapping + self.causal = causal @abstractmethod def forward( @@ -82,6 +84,7 @@ def build( alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> AttentionImpl[T]: """build.""" diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index ef538f7a3d..263b419f1a 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -12,7 +12,8 @@ class OpType(Enum): """Layer type enumerate.""" - Attention = auto() + PagedAttention = auto() + FlashAttention = auto() Linear = auto() RotaryEmbedding = auto() ApplyRotaryEmb = auto() @@ -27,6 +28,9 @@ class OpType(Enum): LinearW4A16 = auto() SoftmaxTopK = auto() FusedMoE = auto() + 
FusedMoEW8A8 = auto() + LinearBlockedF8 = auto() + FusedMoEBlockedF8 = auto() class OpsBackend(ABC): diff --git a/lmdeploy/pytorch/backends/blockedf8_modules.py b/lmdeploy/pytorch/backends/blockedf8_modules.py new file mode 100644 index 0000000000..d79b41330c --- /dev/null +++ b/lmdeploy/pytorch/backends/blockedf8_modules.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class LinearBlockedF8Impl(ABC): + """linear BlockedF8 implementation api.""" + + def update_weights(self, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """update weights.""" + return weight, scale, bias + + @abstractmethod + def forward(self, + x, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): + """forward.""" + raise NotImplementedError + + +class LinearBlockedF8Builder(ABC): + """linear BlockedF8 implementation builder.""" + + @staticmethod + @abstractmethod + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None): + """build.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 1672803ff4..31546ae0e1 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -42,6 +42,7 @@ def __init__( alibi: bool = False, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ): super().__init__( @@ -53,8 +54,10 @@ def __init__( alibi=alibi, sliding_window=sliding_window, logit_softcapping=logit_softcapping, + causal=causal, **kwargs, ) + assert not (alibi and not causal) from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd, fill_kv_cache, @@ -177,6 +180,7 @@ def forward( window_size=self.sliding_window, sm_scale=self.scale, logit_softcapping=self.logit_softcapping, + causal=self.causal, ) else: self.alibi_paged_attention_fwd( @@ -212,6 +216,7 @@ def build( alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> TritonAttentionImpl: """build.""" @@ -223,4 +228,5 @@ def build( alibi=alibi, sliding_window=sliding_window, logical_softcapping=logical_softcapping, + causal=causal, **kwargs) diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index 8159bbf554..18b0150493 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -18,23 +18,14 @@ def wq_gemm_forward( out_features=0, ): """wq gemm forward.""" - from awq.modules.linear.gemm import awq_ext - from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_linear out_shape = x.shape[:-1] + (out_features, ) input_dtype = x.dtype if input_dtype != torch.float16: x = x.half() - FP16_MATMUL_HEURISTIC_CONDITION = x.size(0) * x.size(1) >= 64 - x = x.flatten(0, -2) - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_linear(x, qweight, scales, qzeros) - else: - if not x.is_contiguous(): - x = x.contiguous() - out = awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8) + out = awq_linear(x, qweight, scales, qzeros) out = out + bias if bias is not None else out out = out.reshape(out_shape) diff --git a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py new file mode 100644 index 0000000000..8299ac2dfd --- /dev/null +++ 
b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.distributed as dist + +from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import (blocked_gemm_fp8, + quant_fp8) + +from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl + + +class TritonLinearBlockedF8Impl(LinearBlockedF8Impl): + """triton linear blocked f8 implementation.""" + + def __init__(self, + in_features: int, + out_features: int, + block_size: int, + out_dtype: torch.dtype = torch.float16): + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.block_size = block_size + + def forward(self, + x, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): + """forward.""" + x_shape = x.shape + x = x.flatten(0, -2) + input_quant, input_scale = quant_fp8(x, + self.block_size, + dtype=weight.dtype) + + out = blocked_gemm_fp8(input_quant, + input_scale, + weight.t(), + scale.t(), + out_dtype=x.dtype) + if bias is not None: + out += bias + + if all_reduce: + dist.all_reduce(out) + + out = out.unflatten(0, x_shape[:-1]) + return out + + +class TritonLinearBlockedF8Builder(LinearBlockedF8Builder): + """triton linear blocked f8 implementation builder.""" + + @staticmethod + def build(in_features: int, + out_features: int, + block_size: int = 128, + bias: bool = True, + dtype: torch.dtype = None): + """build.""" + return TritonLinearBlockedF8Impl(in_features, out_features, block_size, + dtype) diff --git a/lmdeploy/pytorch/backends/cuda/flash_attention.py b/lmdeploy/pytorch/backends/cuda/flash_attention.py new file mode 100644 index 0000000000..5d3925b744 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/flash_attention.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from torch import Tensor + +from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl + + +class TritonFlashAttentionImpl(FlashAttentionImpl): + """triton flash attention implementation.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + ): + if scale is None: + scale = 1.0 / (head_dim**0.5) + + if num_kv_heads is None: + num_kv_heads = num_heads + + if v_head_dim is None: + v_head_dim = head_dim + + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.v_head_dim = v_head_dim + self.causal = causal + self.sliding_window = sliding_window + self.logical_softcapping = logical_softcapping + + from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd + self.flash_attention_fwd = flash_attention_fwd + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_dim, ) + out = query.new_empty(o_shape) + self.flash_attention_fwd( + query, + key, + value, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logical_softcapping, + causal=self.causal, + kv_layout='shd', + ) + + return out + + +class TritonFlashAttentionBuilder(FlashAttentionBuilder): + """triton attention builder.""" + + @staticmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + return TritonFlashAttentionImpl( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logical_softcapping=logical_softcapping, + ) diff --git a/lmdeploy/pytorch/backends/cuda/lora.py b/lmdeploy/pytorch/backends/cuda/lora.py index 798d985715..b65a01df14 100644 --- a/lmdeploy/pytorch/backends/cuda/lora.py +++ b/lmdeploy/pytorch/backends/cuda/lora.py @@ -50,6 +50,15 @@ def forward(self, """forward.""" lora_input = self._make_packed_lora_input(x, ctx_mgr) + base_slice = adapter_info.base_slice + sliced_base = base_output[..., base_slice] + + if base_output.is_contiguous(): + kernel_output = sliced_base.flatten(0, -2) + cum = True + else: + kernel_output = None + cum = False lora_out = fused_lora( lora_input.x, lora_A, @@ -62,14 +71,14 @@ def forward(self, adapter_ids=lora_input.adapter_ids, max_rank=adapter_info.max_rank, max_seqlen=lora_input.max_seq_len, + output=kernel_output, + cum=cum, ) - base_slice = adapter_info.base_slice - sliced_base = base_output[..., base_slice] - lora_out = lora_out.reshape(sliced_base.shape) - sliced_base.add_(lora_out) - output = base_output - return output + if not base_output.is_contiguous(): + lora_out = lora_out.reshape(sliced_base.shape) + sliced_base.add_(lora_out) + return base_output class TritonLoRABuilder(LoRABuilder): diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index eb38401211..a913ca82fb 100644 --- 
a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -4,9 +4,17 @@ import torch -from lmdeploy.pytorch.kernels.cuda import fused_moe +from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8 +from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \ + fused_moe_blocked_fp8 +from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 +from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \ + per_token_quant_int8 +from lmdeploy.pytorch.models.q_modules import QTensor -from ..moe import FusedMoEBuilder, FusedMoEImpl +from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, + FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder, + FusedMoEW8A8Impl) class TritonFusedMoEImpl(FusedMoEImpl): @@ -74,3 +82,185 @@ def build(top_k: int, num_experts: int, renormalize: bool = False): return TritonFusedMoEImpl(top_k=top_k, num_experts=num_experts, renormalize=renormalize) + + +class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl): + """triton fused moe w8a8 implementation.""" + + def __init__(self, + top_k: int, + num_experts: int, + renormalize: bool = False, + out_dtype: torch.dtype = torch.float16): + self.num_experts = num_experts + self.top_k = top_k + self.renormalize = renormalize + self.out_dtype = out_dtype + + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, gate_up_scale: torch.Tensor, + down_scale: torch.Tensor): + gate_up_weights = gate_up_weights.transpose(1, + 2).contiguous().transpose( + 1, 2) + down_weights = down_weights.transpose(1, + 2).contiguous().transpose(1, 2) + return gate_up_weights, down_weights, gate_up_scale, down_scale + + def support_ep(self): + """support expert parallelism.""" + return True + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: List[int] = None): + """forward.""" + + if isinstance(hidden_states, torch.Tensor): + hidden_states = hidden_states.contiguous() + input_quant, input_scale = per_token_quant_int8( + hidden_states, 1e-7) + else: + assert isinstance(hidden_states, QTensor) + input_quant, input_scale = (hidden_states.tensor, + hidden_states.scale) + + expert_offset = 0 + num_experts = None + if expert_list is not None and len(expert_list) != self.num_experts: + expert_offset = expert_list[0] + num_experts = self.num_experts + return fused_moe_w8a8(input_quant, + input_scale, + gate_up_weights, + gate_up_scale, + down_weights, + down_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + topk=self.top_k, + out_dtype=self.out_dtype, + expert_offset=expert_offset, + num_experts=num_experts, + renormalize=self.renormalize) + + +class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder): + """triton fused moe w8a8 builder.""" + + @staticmethod + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + out_dtype: torch.dtype = torch.float16): + """build from mlp.""" + return TritonFusedMoEW8A8Impl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize, + out_dtype=out_dtype) + + +class 
TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl): + """triton fused moe blocked f8 implementation.""" + + def __init__(self, + top_k: int, + num_experts: int, + renormalize: bool = False, + block_size: int = 128, + out_dtype: torch.dtype = torch.float16): + self.num_experts = num_experts + self.top_k = top_k + self.renormalize = renormalize + self.block_size = block_size + self.out_dtype = out_dtype + + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, gate_up_scale: torch.Tensor, + down_scale: torch.Tensor): + gate_up_weights = gate_up_weights.transpose(1, + 2).contiguous().transpose( + 1, 2) + down_weights = down_weights.transpose(1, + 2).contiguous().transpose(1, 2) + return gate_up_weights, down_weights, gate_up_scale, down_scale + + def support_ep(self): + """support expert parallelism.""" + return True + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: List[int] = None): + """forward.""" + input_size = hidden_states.shape + hidden_states = hidden_states.flatten(0, -2) + input_quant, input_scale = quant_fp8(hidden_states, + self.block_size, + dtype=gate_up_weights.dtype) + + expert_offset = 0 + num_experts = None + if expert_list is not None and len(expert_list) != self.num_experts: + expert_offset = expert_list[0] + num_experts = self.num_experts + output = fused_moe_blocked_fp8(input_quant, + input_scale, + gate_up_weights, + gate_up_scale, + down_weights, + down_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + topk=self.top_k, + out_dtype=hidden_states.dtype, + expert_offset=expert_offset, + num_experts=num_experts, + renormalize=self.renormalize) + output = output.unflatten(0, input_size[:-1]) + return output + + +class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): + """triton fused moe blocked f8 builder.""" + + @staticmethod + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + block_size: int = 128, + out_dtype: torch.dtype = torch.float16): + """build from mlp.""" + return TritonFusedMoEBlockedF8Impl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize, + block_size=block_size, + out_dtype=out_dtype) diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index bfd77a250f..cbc46352a5 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -23,9 +23,12 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: OpType): """get cuda layer builder.""" - if layer_type == OpType.Attention: + if layer_type == OpType.PagedAttention: from .attention import TritonAttentionBuilder return TritonAttentionBuilder + elif layer_type == OpType.FlashAttention: + from .flash_attention import TritonFlashAttentionBuilder + return TritonFlashAttentionBuilder elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import TritonApplyRotaryEmbBuilder return TritonApplyRotaryEmbBuilder @@ -48,21 +51,20 @@ def get_layer_impl_builder(cls, 
layer_type: OpType): from .activation import TritonSiluAndMulBuilder return TritonSiluAndMulBuilder elif layer_type == OpType.LinearW4A16: - try: - from awq.modules.linear.gemm import awq_ext # noqa: F401 - AWQ_INSTALLED = True - except Exception: - AWQ_INSTALLED = False - if AWQ_INSTALLED: - from .awq_modules import AwqLinearW4A16Builder - return AwqLinearW4A16Builder - else: - logger.debug( - f'Op {layer_type} fallback to default implementation.') - return super().get_layer_impl_builder(layer_type) + from .awq_modules import AwqLinearW4A16Builder + return AwqLinearW4A16Builder elif layer_type == OpType.FusedMoE: from .moe import TritonFusedMoEBuilder return TritonFusedMoEBuilder + elif layer_type == OpType.FusedMoEW8A8: + from .moe import TritonFusedMoEW8A8Builder + return TritonFusedMoEW8A8Builder + elif layer_type == OpType.FusedMoEBlockedF8: + from .moe import TritonFusedMoEBlockedF8Builder + return TritonFusedMoEBlockedF8Builder + elif layer_type == OpType.LinearBlockedF8: + from .blockedf8_modules import TritonLinearBlockedF8Builder + return TritonLinearBlockedF8Builder else: logger.debug( f'Op {layer_type} fallback to default implementation.') @@ -142,30 +144,30 @@ def update_step_context(cls, step_context): medusa_attn_mask=step_context.medusa_attn_mask, ) - cross_attn_metadata = None - fill_seqlens = None - if step_context.cross_attention_states is not None: - fill_seqlens = torch.zeros_like(q_seqlens) - for idx, state in enumerate(step_context.cross_attention_states): - if state is not None: - fill_seqlens[idx] = state.shape[-2] + cross_seqlens = step_context.cross_seqlens cross_kv_seqlens = step_context.cross_kv_seqlens - cross_kv_start_loc = None - cross_kv_flatten_size = None - if not step_context.is_decoding and cross_kv_seqlens is not None: - cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens - cross_kv_flatten_size = cross_kv_seqlens.sum().item() - cross_attn_metadata = attn_meta_cls( - step_context.is_decoding, - step_context.block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seqlens, - kv_start_loc=cross_kv_start_loc, - kv_seqlens=cross_kv_seqlens, - kv_flatten_size=cross_kv_flatten_size, - fill_seqlens=fill_seqlens, - quant_policy=step_context.kv_quant_policy, - ) + cross_attn_metadata = None + if cross_seqlens is not None: + fill_seqlens = cross_seqlens + if fill_seqlens.sum().item() == 0: + fill_seqlens = None + cross_kv_start_loc = None + cross_kv_flatten_size = None + if not step_context.is_decoding and cross_kv_seqlens is not None: + cross_kv_start_loc = cross_kv_seqlens.cumsum( + 0) - cross_kv_seqlens + cross_kv_flatten_size = cross_kv_seqlens.sum().item() + cross_attn_metadata = attn_meta_cls( + step_context.is_decoding, + step_context.block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=cross_kv_start_loc, + kv_seqlens=cross_kv_seqlens, + kv_flatten_size=cross_kv_flatten_size, + fill_seqlens=fill_seqlens, + quant_policy=step_context.kv_quant_policy, + ) step_context.attn_metadata = attn_metadata step_context.cross_attn_metadata = cross_attn_metadata diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py index 30f729a63f..13d9a47ddf 100644 --- a/lmdeploy/pytorch/backends/cuda/qmodules.py +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -15,42 +15,62 @@ class TritonRMSNormW8A8Impl(RMSNormW8A8Impl): """triton RMS norm w8a8 implementation api.""" - def __init__(self, hidden_size: int, eps: float = 1e-6): + def __init__(self, + hidden_size: int, + eps: float = 1e-6, + 
quant_dtype: torch.dtype = torch.int8): super().__init__() self.hidden_size = hidden_size self.eps = eps + self.quant_dtype = quant_dtype def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None): """forward.""" - if residual is not None: - x = x + residual - residual = x - hidden_states_quant, rms_scale = rms_norm_dynamic_quant( - x, weight, self.eps) - x = QTensor(hidden_states_quant, rms_scale) if residual is None: + (x, + rms_scale) = rms_norm_dynamic_quant(x, + weight, + self.eps, + quant_dtype=self.quant_dtype) + x = QTensor(x, rms_scale) return x - return x, residual + else: + (x, rms_scale, + residual) = rms_norm_dynamic_quant(x, + weight, + self.eps, + residual=residual, + quant_dtype=self.quant_dtype) + x = QTensor(x, rms_scale) + return x, residual class TritonRMSNormBuilder(RMSNormW8A8Builder): """triton RMS norm w8a8 implementation builder.""" @staticmethod - def build(hidden_size: int, eps: float = 1e-6): + def build(hidden_size: int, + eps: float = 1e-6, + quant_dtype: torch.dtype = torch.int8): """build.""" - return TritonRMSNormW8A8Impl(hidden_size, eps) + return TritonRMSNormW8A8Impl(hidden_size, eps, quant_dtype) class TritonLinearW8A8Impl(LinearW8A8Impl): """triton linear w8a8 implementation.""" - def __init__(self, in_features: int, out_features: int): + def __init__(self, + in_features: int, + out_features: int, + out_dtype: torch.dtype = torch.float16, + quant_dtype: torch.dtype = torch.int8): self.in_features = in_features self.out_features = out_features + self.out_dtype = out_dtype + self.quant_dtype = quant_dtype def forward(self, x, @@ -60,8 +80,8 @@ def forward(self, all_reduce: bool = False): """forward.""" if isinstance(x, torch.Tensor): - x = x.contiguous() - input_quant, input_scale = per_token_quant_int8(x, 1e-7) + input_quant, input_scale = per_token_quant_int8( + x, 1e-7, quant_dtype=self.quant_dtype) else: assert isinstance(x, QTensor) input_quant, input_scale = x.tensor, x.scale @@ -70,7 +90,7 @@ def forward(self, weight, input_scale, scale, - output_dtype=torch.float16, + output_dtype=self.out_dtype, bias=bias) if all_reduce: @@ -85,6 +105,10 @@ class TritonLinearW8A8Builder(LinearW8A8Builder): def build(in_features: int, out_features: int, bias: bool = True, - dtype: torch.dtype = None): + dtype: torch.dtype = None, + quant_dtype: torch.dtype = torch.int8): """build.""" - return TritonLinearW8A8Impl(in_features, out_features) + return TritonLinearW8A8Impl(in_features, + out_features, + dtype, + quant_dtype=quant_dtype) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py index f9664f13ff..e3c5dc4d5e 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py @@ -33,10 +33,17 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, dlinfer.graph.config.enable_graph_mode = True self.patch_kernels_custom_op() self.patch_kvcache_static_shape() - self.model = torch.compile(self.model, - fullgraph=True, - dynamic=True, - backend='atbgraph') + if hasattr(self.model, 'language_model'): + self.model.language_model = torch.compile( + self.model.language_model, + fullgraph=True, + dynamic=True, + backend='atbgraph') + else: + self.model = torch.compile(self.model, + fullgraph=True, + dynamic=True, + backend='atbgraph') def check_enable_graph(self): """check enable graph.""" diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py 
b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index b6f544510b..588558f0d5 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple +import itertools +import os +import re +from pathlib import Path +from typing import Dict, Tuple import torch @@ -11,6 +15,71 @@ logger = get_logger('lmdeploy') +class AscendKVQuantMeta: + has_set_value: bool = False + quant_meta: Dict = {} + + @classmethod + def set_value(cls, device: str, dtype: torch.dtype, record_file: str, + total_layers: int): + with open(record_file, 'r') as file: + data = file.read() + scale_offset_pairs = re.findall( + r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data) + scale_offset_pairs = [(float(scale), float(offset)) + for scale, offset in scale_offset_pairs] + k_scales, v_scales, kv_scales = [], [], [] + k_zeros, v_zeros, kv_zeros = [], [], [] + if len(scale_offset_pairs) == total_layers: + for scale, offset in scale_offset_pairs: + k_scales.append( + torch.tensor([scale], device=device, dtype=dtype)) + v_scales.append( + torch.tensor([scale], device=device, dtype=dtype)) + kv_scales.append( + torch.tensor([scale, scale], device=device, dtype=dtype)) + k_zeros.append( + torch.tensor([offset], device=device, dtype=dtype)) + v_zeros.append( + torch.tensor([offset], device=device, dtype=dtype)) + kv_zeros.append( + torch.tensor([offset, offset], device=device, dtype=dtype)) + elif len(scale_offset_pairs) == total_layers * 2: + for i in range(total_layers): + scale_k, offset_k = scale_offset_pairs[2 * i] + scale_v, offset_v = scale_offset_pairs[2 * i + 1] + k_scales.append( + torch.tensor([scale_k], device=device, dtype=dtype)) + v_scales.append( + torch.tensor([scale_v], device=device, dtype=dtype)) + kv_scales.append( + torch.tensor([scale_k, scale_v], + device=device, + dtype=dtype)) + k_zeros.append( + torch.tensor([offset_k], device=device, dtype=dtype)) + v_zeros.append( + torch.tensor([offset_v], device=device, dtype=dtype)) + kv_zeros.append( + torch.tensor([offset_k, offset_v], + device=device, + dtype=dtype)) + else: + raise ValueError( + f'num of scale_offset_pairs({len(scale_offset_pairs)}) ' + f'must match num of total_layers({total_layers})') + + cls.quant_meta.update({ + 'k_scales': itertools.cycle(k_scales), + 'k_zeros': itertools.cycle(k_zeros), + 'v_scales': itertools.cycle(v_scales), + 'v_zeros': itertools.cycle(v_zeros), + 'kv_scales': itertools.cycle(kv_scales), + 'kv_zeros': itertools.cycle(kv_zeros) + }) + cls.has_set_value = True + + class AscendOpsBackend(DlinferOpsBackend): """ascend layer backend.""" enable_graph = False @@ -164,6 +233,21 @@ def get_total_slots(): .repeat_interleave(step_context.q_seqlens, 0) kv_seqlens = kv_seqlens_cpu + if not cls.enable_graph and step_context.kv_quant_policy == 8: + record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') + assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE' + path = Path(record_file) + is_path = path.is_absolute() or path.is_relative_to('/') + exists = path.exists() + if not (is_path and exists): + raise ValueError( + 'please specify valid ASCEND_QUANT_RECORD_FILE') + if not AscendKVQuantMeta.has_set_value: + total_layers = len(step_context.kv_caches) + AscendKVQuantMeta.set_value(step_context.block_offsets.device, + step_context.model_config.dtype, + record_file, total_layers) + attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( 
step_context.is_decoding, @@ -177,6 +261,8 @@ def get_total_slots(): is_unpaged_prefill=is_unpaged_prefill, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, + quant_policy=step_context.kv_quant_policy, + quant_meta=AscendKVQuantMeta.quant_meta, ) step_context.attn_metadata = attn_metadata diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 0d666c9130..6b03403c84 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from dataclasses import dataclass -from typing import Optional, Sequence +from typing import Dict, Optional, Sequence from torch import Tensor @@ -15,6 +15,7 @@ class DlinferAttentionMetadata(AttentionMetadata): is_unpaged_prefill: Optional[bool] = None max_q_seq_len: int = 1 max_kv_seq_len: int = 1 + quant_meta: Dict = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): @@ -30,8 +31,10 @@ def __init__( alibi: bool = None, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ): + assert causal super().__init__( num_heads, head_size, @@ -41,6 +44,7 @@ def __init__( alibi, sliding_window, logit_softcapping, + causal=causal, **kwargs, ) @@ -74,10 +78,37 @@ def forward( is_unpaged_prefill = attn_metadata.is_unpaged_prefill max_q_seq_len = attn_metadata.max_q_seq_len max_kv_seq_len = attn_metadata.max_kv_seq_len + quant_bits = attn_metadata.quant_policy + if attn_metadata.quant_meta is not None: + k_scales_zeros = [ + next(attn_metadata.quant_meta['k_scales']), + next(attn_metadata.quant_meta['k_zeros']) + ] if 'k_scales' in attn_metadata.quant_meta else [] + v_scales_zeros = [ + next(attn_metadata.quant_meta['v_scales']), + next(attn_metadata.quant_meta['v_zeros']) + ] if 'v_scales' in attn_metadata.quant_meta else [] + kv_scales = next( + attn_metadata.quant_meta['kv_scales'] + ) if 'kv_scales' in attn_metadata.quant_meta else None + kv_zeros = next( + attn_metadata.quant_meta['kv_zeros'] + ) if 'kv_zeros' in attn_metadata.quant_meta else None + else: + k_scales_zeros = [] + v_scales_zeros = [] + kv_scales = None + kv_zeros = None # fill kv cache - k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, - kv_start_indices) + k_cache, v_cache = self.fill_kv_cache(key, + value, + k_cache, + v_cache, + kv_start_indices, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros, + quant_bits=quant_bits) if inplace: attn_output = query[..., :self.v_head_size] @@ -103,6 +134,9 @@ def forward( block_size=block_size, attn_mask=attn_mask, is_unpaged_prefill=is_unpaged_prefill, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) return attn_output @@ -121,6 +155,7 @@ def build( alibi_scale: float = None, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> DlinferAttentionImpl: """build.""" @@ -132,4 +167,5 @@ def build( alibi_scale=alibi_scale, sliding_window=sliding_window, logical_softcapping=logical_softcapping, + causal=causal, **kwargs) diff --git a/lmdeploy/pytorch/backends/dlinfer/flash_attention.py b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py new file mode 100644 index 0000000000..d0d9ddbb26 --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
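To make the new Ascend KV-cache int8 path above concrete, here is a sketch of the calibration record the backend expects and of how AscendKVQuantMeta.set_value consumes it. In the engine the file is located through the ASCEND_QUANT_RECORD_FILE environment variable when quant_policy is 8; the contents below are purely illustrative, only the 'scale: ... offset: ...' pattern matters, with either one pair per layer or separate k/v pairs (two per layer).

import torch

from lmdeploy.pytorch.backends.dlinfer.ascend.op_backend import AscendKVQuantMeta

# hypothetical record from a calibration run: one scale/offset pair per layer
record = '\n'.join(f'scale: {0.01 * (i + 1):.4f} offset: -2' for i in range(4))
with open('/tmp/ascend_kv_quant.txt', 'w') as f:
    f.write(record)

# device is a stand-in here; in the engine it is the device of block_offsets
AscendKVQuantMeta.set_value(device='cpu',
                            dtype=torch.float16,
                            record_file='/tmp/ascend_kv_quant.txt',
                            total_layers=4)

# every entry is an itertools.cycle; the dlinfer attention impl calls next() once per layer
k_scale = next(AscendKVQuantMeta.quant_meta['k_scales'])
kv_zeros = next(AscendKVQuantMeta.quant_meta['kv_zeros'])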
+from torch import Tensor + +from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl + + +class DlinferFlashAttentionImpl(FlashAttentionImpl): + """dlinfer flash attention implementation.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + ): + if scale is None: + scale = 1.0 / (head_dim**0.5) + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_dim is None: + v_head_dim = head_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.v_head_dim = v_head_dim + self.causal = causal + self.sliding_window = sliding_window + self.logical_softcapping = logical_softcapping + from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd + self.flash_attention_fwd = flash_attention_fwd + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_dim, ) + out = query.new_empty(o_shape) + self.flash_attention_fwd( + query, + key, + value, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_q_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logical_softcapping, + causal=self.causal, + ) + return out + + +class DlinferFlashAttentionBuilder(FlashAttentionBuilder): + """dlinfer attention builder.""" + + @staticmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + return DlinferFlashAttentionImpl( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logical_softcapping=logical_softcapping, + ) diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py index 6ada730fbe..ff986c5765 100644 --- a/lmdeploy/pytorch/backends/dlinfer/moe.py +++ b/lmdeploy/pytorch/backends/dlinfer/moe.py @@ -47,8 +47,8 @@ def forward(self, down_weights: torch.Tensor, expert_list: List[int] = None): """forward.""" - return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights, - gate_up_weights, down_weights) + return fused_moe(hidden_states, gate_up_weights, down_weights, + topk_weights, topk_ids, self.top_k, self.renormalize) class DlinferFusedMoEBuilder(FusedMoEBuilder): diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py index 52a8830595..a0f04f34b1 100644 --- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py @@ -22,9 +22,12 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: OpType): """get dlinfer layer builder.""" - if layer_type == OpType.Attention: + if layer_type == OpType.PagedAttention: from .attention import DlinferAttentionBuilder return DlinferAttentionBuilder + elif layer_type == OpType.FlashAttention: + from .flash_attention import DlinferFlashAttentionBuilder + return DlinferFlashAttentionBuilder elif 
layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder return DlinferApplyRotaryEmbBuilder diff --git a/lmdeploy/pytorch/backends/flash_attention.py b/lmdeploy/pytorch/backends/flash_attention.py new file mode 100644 index 0000000000..bed3af8d68 --- /dev/null +++ b/lmdeploy/pytorch/backends/flash_attention.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +from torch import Tensor + + +class FlashAttentionImpl(ABC): + """FlashAttention implementation.""" + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + raise NotImplementedError + + +class FlashAttentionBuilder(ABC): + """FlashAttention implementation builder.""" + + @staticmethod + @abstractmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 9ab66b26a2..9347995e0b 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -46,3 +46,26 @@ def prepare_inputs_for_generation( inputs_embeds, context, ) + + def update_model_metas( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare inputs.""" + if hasattr(self.model, 'update_model_metas'): + return self.model.update_model_metas( + past_key_values, + inputs_embeds, + context, + ) + + return None + + def get_input_processor(self): + """get input processor.""" + if hasattr(self.model, 'get_input_processor'): + return self.model.get_input_processor() + else: + return None diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py index 8e7977625e..4501e52c0b 100644 --- a/lmdeploy/pytorch/backends/moe.py +++ b/lmdeploy/pytorch/backends/moe.py @@ -60,3 +60,93 @@ class FusedMoEBuilder(ABC): def build(top_k: int, num_experts: int, renormalize: bool = False): """build from mlp.""" raise NotImplementedError + + +class FusedMoEW8A8Impl(ABC): + """fused moe w8a8 implementation.""" + + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, gate_up_scale: torch.Tensor, + down_scale: torch.Tensor): + """update weights.""" + return gate_up_weights, down_weights, gate_up_scale, down_scale + + def support_ep(self): + """support expert parallelism.""" + return False + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + raise NotImplementedError('Not Implemented.') + + @abstractmethod + def forward(self, + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: List[int] = None): + """forward.""" + raise NotImplementedError + + +class FusedMoEW8A8Builder(ABC): + """fused moe w8a8 builder.""" + + @staticmethod + @abstractmethod + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + out_dtype: torch.dtype = torch.float16): + """build from mlp.""" + raise 
NotImplementedError + + +class FusedMoEBlockedF8Impl(ABC): + """fused moe blocked f8 implementation.""" + + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, gate_up_scale: torch.Tensor, + down_scale: torch.Tensor): + """update weights.""" + return gate_up_weights, down_weights, gate_up_scale, down_scale + + def support_ep(self): + """support expert parallelism.""" + return False + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + raise NotImplementedError('Not Implemented.') + + @abstractmethod + def forward(self, + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: List[int] = None): + """forward.""" + raise NotImplementedError + + +class FusedMoEBlockedF8Builder(ABC): + """fused moe blocked f8 builder.""" + + @staticmethod + @abstractmethod + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + out_dtype: torch.dtype = torch.float16): + """build from mlp.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py index a61941b37d..e877a4ca6b 100644 --- a/lmdeploy/pytorch/backends/qmodules.py +++ b/lmdeploy/pytorch/backends/qmodules.py @@ -37,7 +37,9 @@ class RMSNormW8A8Builder(ABC): @staticmethod @abstractmethod - def build(hidden_size: int, eps: float = 1e-6): + def build(hidden_size: int, + eps: float = 1e-6, + quant_dtype: torch.dtype = torch.int8): """build.""" raise NotImplementedError @@ -71,6 +73,7 @@ class LinearW8A8Builder(ABC): def build(in_features: int, out_features: int, bias: bool = True, - dtype: torch.dtype = None): + dtype: torch.dtype = None, + quant_dtype: torch.dtype = torch.int8): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 7d72438224..bc95a32be6 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -1,277 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from logging import Logger -from typing import List - -from lmdeploy.utils import get_logger - - -def _handle_exception(e: Exception, - mod_name: str, - logger: Logger, - message: str = None): - red_color = '\033[31m' - reset_color = '\033[0m' - if message is None: - message = 'Please ensure it has been installed correctly.' 
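The expert-parallel split that the Triton MoE implementations above opt into (support_ep returning True) is a plain ceil-division partition. A quick standalone check of that arithmetic, with a hypothetical 6 experts over 4 ranks:

def ep_expert_list(num_experts: int, world_size: int, rank: int):
    """mirrors TritonFusedMoEW8A8Impl/TritonFusedMoEBlockedF8Impl.ep_expert_list."""
    expert_per_rank = (num_experts + world_size - 1) // world_size
    first_expert = rank * expert_per_rank
    last_expert = min(first_expert + expert_per_rank, num_experts)
    return list(range(first_expert, last_expert))

print([ep_expert_list(6, 4, rank) for rank in range(4)])
# [[0, 1], [2, 3], [4, 5], []] -> the last rank can end up owning no experts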
- logger.debug('Exception', exc_info=1) - logger.error(f'{type(e).__name__}: {e}') - logger.error(f'{red_color}' - f'<{mod_name}> test failed!\n' - f'{message}' - f'{reset_color}') - exit(1) +from .base import BaseChecker # noqa: F401 def check_env_deeplink(device_type: str): """check Deeplink environment.""" - try_import_deeplink(device_type) + from .deeplink import DeeplinkChecker + checker = DeeplinkChecker(device_type) + checker.handle() def try_import_deeplink(device_type: str): - """import dlinfer if specific device_type is set.""" - deeplink_device_type_list = [ - 'ascend', - 'npu', - 'maca', - ] - if device_type in deeplink_device_type_list: - logger = get_logger('lmdeploy') - try: - import dlinfer.framework.lmdeploy_ext # noqa: F401 - except Exception as e: - _handle_exception(e, 'PyTorch', logger) - - -def check_env_torch(): - """check PyTorch environment.""" - logger = get_logger('lmdeploy') - - try: - logger.debug('Checking environment.') - import torch - - a = torch.tensor([1, 2], device='cuda') - b = a.new_tensor([3, 4], device='cuda') - c = a + b - torch.testing.assert_close(c, a.new_tensor([4, 6])) - except Exception as e: - _handle_exception(e, 'PyTorch', logger) - - -MAX_TRITON_VERSION = '3.0.0' - - -def check_env_triton(device: str): - """check OpenAI Triton environment.""" - from packaging import version - logger = get_logger('lmdeploy') - - msg = ( - 'Please ensure that your device is functioning properly with .\n' # noqa: E501 - 'You can verify your environment by running ' - '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.') - try: - logger.debug('Checking environment.') - import torch - import triton - triton_version = version.parse(triton.__version__) - if triton_version > version.parse(MAX_TRITON_VERSION): - logger.warning( - f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.') - - from .triton_custom_add import custom_add - a = torch.tensor([1, 2], device='cuda') - b = a.new_tensor([3, 4], device='cuda') - c = custom_add(a, b) - torch.testing.assert_close(c, a + b) - except RuntimeError as e: - ptxas_error = 'device kernel image is invalid' - if len(e.args) > 0 and ptxas_error in e.args[0]: - msg = ( - 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501 - 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501 - ' or reinstall the driver.') - _handle_exception(e, 'Triton', logger, msg) - except Exception as e: - _handle_exception(e, 'Triton', logger, msg) - - if device == 'cuda': - device_cap = torch.cuda.get_device_capability() - TRITON_VER_231 = version.parse('2.3.1') - - if device_cap[0] <= 7: - if triton_version <= TRITON_VER_231: - err = RuntimeError( - 'Attention triton kernel does not fully support ' - 'triton<3.0.0 on device with capability<8. 
' - 'Please upgrade your triton version.') - _handle_exception(err, 'Triton', logger) - - -def check_env(device_type: str): - """check all environment.""" - logger = get_logger('lmdeploy') - logger.info('Checking environment for PyTorch Engine.') + """check Deeplink environment.""" check_env_deeplink(device_type) - check_env_torch() - if device_type == 'cuda': - check_env_triton('cuda') - - -MIN_TRANSFORMERS_VERSION = '4.33.0' -MAX_TRANSFORMERS_VERSION = '4.44.1' - - -def check_awq(hf_config, device_type): - """check awq support.""" - logger = get_logger('lmdeploy') - if device_type == 'cuda': - quantization_config = getattr(hf_config, 'quantization_config', dict()) - quant_method = quantization_config.get('quant_method', None) - if quant_method != 'awq': - return - try: - import awq # noqa - except Exception as e: - _handle_exception(e, 'autoawq', logger) - - try: - import awq_ext # noqa - except Exception: - logger.debug('Exception:', exc_info=1) - logger.warning('Failed to import `awq_ext`. ' - 'Try reinstall it from source: ' - 'https://github.com/casper-hansen/AutoAWQ_kernels') - - -def check_transformers_version(model_path: str, - trust_remote_code: bool = True, - dtype: str = 'auto', - device_type: str = 'cuda'): - """check transformers version.""" - from packaging import version - logger = get_logger('lmdeploy') - - def __check_transformers_version(): - """check transformers version.""" - logger.debug('Checking version.') - trans_version = None - try: - import transformers - trans_version = version.parse(transformers.__version__) - min_version = version.parse(MIN_TRANSFORMERS_VERSION) - max_version = version.parse(MAX_TRANSFORMERS_VERSION) - if trans_version < min_version or trans_version > max_version: - logger.warning('LMDeploy requires transformers version: ' - f'[{MIN_TRANSFORMERS_VERSION} ~ ' - f'{MAX_TRANSFORMERS_VERSION}], ' - 'but found version: ' - f'{transformers.__version__}') - except Exception as e: - _handle_exception(e, 'transformers', logger) - return transformers, trans_version - - def __check_config(trans_version): - """check config.""" - logger.debug('Checking AutoConfig.from_pretrained.') - try: - from transformers import AutoConfig - config = AutoConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code) - except Exception as e: - message = ( - f'Load model config with transformers=={trans_version}' - ' failed. ' - 'Please make sure model can be loaded with transformers API.') - _handle_exception(e, 'transformers', logger, message=message) - return config - - def __check_model_transformers_version(config, trans_version): - """check model transformers version.""" - logger.debug('Checking required transformers version.') - try: - model_trans_version = getattr(config, 'transformers_version', None) - if model_trans_version is not None: - model_trans_version = version.parse(model_trans_version) - assert trans_version >= model_trans_version, \ - 'Version mismatch.' 
- except Exception as e: - message = (f'model `{model_path}` requires ' - f'transformers version {model_trans_version} ' - f'but transformers {trans_version} is installed.') - _handle_exception(e, 'transformers', logger, message=message) - - def __check_model_dtype_support(config, device_type): - """Checking model dtype support.""" - logger.debug('Checking dtype support.') - - import torch - - from lmdeploy.pytorch.config import ModelConfig - from lmdeploy.utils import is_bf16_supported - - try: - model_config = ModelConfig.from_hf_config(config, - model_path=model_path, - dtype=dtype) - if model_config.dtype == torch.bfloat16: - assert is_bf16_supported(device_type), ( - 'bf16 is not supported on your device') - except AssertionError as e: - message = ( - f'Your device does not support `{model_config.dtype}`. ' - 'You can set `dtype` to float16 in PyTorchEngineConfig or ' - '`--dtype float16` to api_server.\n' - 'Note that this might have negative effect!') - _handle_exception(e, 'Model', logger, message=message) - except Exception as e: - message = (f'Checking failed with error {e}', - 'Please send issue to LMDeploy with error logs.') - _handle_exception(e, 'Model', logger, message=message) - - return model_config - - _, trans_version = __check_transformers_version() - config = __check_config(trans_version) - __check_model_transformers_version(config, trans_version) - __check_model_dtype_support(config, device_type) - check_awq(config, device_type) - - -def check_model(model_path: str, - trust_remote_code: bool = True, - dtype: str = 'auto', - device_type: str = 'cuda'): - """check model requirements.""" - logger = get_logger('lmdeploy') - logger.info('Checking model.') - check_transformers_version(model_path, trust_remote_code, dtype, - device_type) - - -def check_adapter(path: str): - """check adapter.""" - logger = get_logger('lmdeploy') - logger.debug(f'Checking : {path}.') - - try: - from peft import PeftConfig - PeftConfig.from_pretrained(path) - except Exception as e: - message = ('Please make sure the adapter can be loaded with ' - '`peft.PeftConfig.from_pretrained`\n') - err_msg = '' if len(e.args) == 0 else e.args[0] - if 'got an unexpected keyword argument' in err_msg: - message += ('Or try remove all unexpected keywords ' - 'in `adapter_config.json`.') - _handle_exception(e, 'Model', logger, message=message) - - -def check_adapters(adapter_paths: List[str]): - """check adapters.""" - if len(adapter_paths) <= 0: - return - logger = get_logger('lmdeploy') - logger.info('Checking adapters.') - for path in adapter_paths: - check_adapter(path) diff --git a/lmdeploy/pytorch/check_env/adapter.py b/lmdeploy/pytorch/check_env/adapter.py new file mode 100644 index 0000000000..bcaf5fd0e3 --- /dev/null +++ b/lmdeploy/pytorch/check_env/adapter.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
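The monolithic check_env/check_model/check_adapter helpers deleted above are superseded by the per-concern checker classes added below. How the engine wires them together is not part of this hunk, so the composition here is only an illustrative sketch (the adapter path is a placeholder):

from lmdeploy.pytorch.check_env.adapter import AdapterChecker
from lmdeploy.pytorch.check_env.torch import TorchChecker
from lmdeploy.pytorch.check_env.triton import TritonChecker

# a Triton check is only meaningful once PyTorch itself works, so chain them
triton_checker = TritonChecker()
triton_checker.register_required_checker(TorchChecker(device='cuda'))
triton_checker.handle()  # runs TorchChecker first, then the triton custom-op test

AdapterChecker('/path/to/adapter').handle()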
+from .base import BaseChecker + + +class AdapterChecker(BaseChecker): + """check adapter is available.""" + + def __init__(self, adapter_path: str, logger=None): + super().__init__(logger) + self.adapter_path = adapter_path + + def check(self): + """check.""" + path = self.adapter_path + + try: + import peft # noqa: F401 + except Exception as e: + self.log_and_exit(e, 'Adapter', message='Failed to import peft.') + + try: + from peft import PeftConfig + PeftConfig.from_pretrained(path) + except Exception as e: + message = ('Please make sure the adapter can be loaded with ' + '`peft.PeftConfig.from_pretrained`\n') + err_msg = '' if len(e.args) == 0 else e.args[0] + if 'got an unexpected keyword argument' in err_msg: + message += ('Or try remove all unexpected keywords ' + 'in `adapter_config.json`.') + self.log_and_exit(e, 'Adapter', message=message) diff --git a/lmdeploy/pytorch/check_env/base.py b/lmdeploy/pytorch/check_env/base.py new file mode 100644 index 0000000000..ed5e5a600f --- /dev/null +++ b/lmdeploy/pytorch/check_env/base.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from logging import Logger +from typing import List + +from lmdeploy.utils import get_logger + +RED_COLOR = '\033[31m' +RESET_COLOR = '\033[0m' + + +def _red_text(text: str): + """red text.""" + return f'{RED_COLOR}{text}{RESET_COLOR}' + + +class BaseChecker: + """base checker.""" + + def __init__(self, logger: Logger = None): + if logger is None: + logger = get_logger('lmdeploy') + self.logger = logger + self._is_passed = False + self._required_checker: List[BaseChecker] = list() + + def get_logger(self): + """get logger.""" + return self.logger + + def register_required_checker(self, checker: 'BaseChecker'): + """register_required.""" + self._required_checker.append(checker) + + def handle(self): + """handle check.""" + is_passed = getattr(self, '_is_passed', False) + if not is_passed: + checker_name = type(self).__name__ + self.logger.debug(f'Checking <{checker_name}>:') + for checker in self._required_checker: + checker.handle() + self.check() + self.is_passed = True + + def log_and_exit(self, + e: Exception = None, + mod_name: str = None, + message: str = None): + logger = self.logger + if mod_name is None: + mod_name = type(self).__name__ + if message is None: + message = 'Please check your environment.' + logger.debug('Exception', exc_info=1) + if e is not None: + logger.error(f'{type(e).__name__}: {e}') + logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}') + exit(1) + + def check(self): + """check.""" + raise NotImplementedError('check not implemented.') diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py new file mode 100644 index 0000000000..74ab5a7b87 --- /dev/null +++ b/lmdeploy/pytorch/check_env/deeplink.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
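Additional checks follow the same pattern: subclass BaseChecker, implement check(), and call log_and_exit on failure; handle() takes care of running any registered prerequisites first. The checker below is hypothetical and only illustrates the shape of such a subclass:

from lmdeploy.pytorch.check_env.base import BaseChecker


class CudaDeviceChecker(BaseChecker):
    """hypothetical checker: verify a CUDA device is visible to torch."""

    def check(self):
        """check."""
        import torch
        if not torch.cuda.is_available():
            self.log_and_exit(mod_name='CUDA',
                              message='No CUDA device is visible to PyTorch.')


CudaDeviceChecker().handle()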
+from .base import BaseChecker + +deeplink_device_type_list = [ + 'ascend', + 'npu', + 'maca', +] + + +class DeeplinkChecker(BaseChecker): + """check pytorch is available.""" + + def __init__(self, device_type: str, logger=None) -> None: + super().__init__(logger=logger) + self.device_type = device_type + + def check(self): + """check.""" + device_type = self.device_type + if device_type in deeplink_device_type_list: + try: + import dlinfer.framework.lmdeploy_ext # noqa: F401 + except Exception as e: + self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.') diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py new file mode 100644 index 0000000000..79d8d26e3c --- /dev/null +++ b/lmdeploy/pytorch/check_env/model.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from packaging import version + +from .base import BaseChecker + + +class ModelChecker(BaseChecker): + """check model is available.""" + + def __init__(self, + model_path: str, + trust_remote_code: bool, + dtype: str, + device_type: str, + logger=None) -> None: + super().__init__(logger=logger) + self.model_path = model_path + self.trust_remote_code = trust_remote_code + self.device_type = device_type + self.dtype = dtype + + def check_config(self, trans_version): + """check config.""" + model_path = self.model_path + trust_remote_code = self.trust_remote_code + try: + from transformers import AutoConfig + config = AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code) + except Exception as e: + message = ( + f'Load model config with transformers=={trans_version}' + ' failed. ' + 'Please make sure model can be loaded with transformers API.') + self.log_and_exit(e, 'transformers', message=message) + return config + + def check_trans_version(self, config, trans_version): + """check transformers version.""" + model_path = self.model_path + try: + model_trans_version = getattr(config, 'transformers_version', None) + if model_trans_version is not None: + model_trans_version = version.parse(model_trans_version) + assert trans_version >= model_trans_version, ( + 'Version mismatch.') + except Exception as e: + message = (f'model `{model_path}` requires ' + f'transformers version {model_trans_version} ' + f'but transformers {trans_version} is installed.') + self.log_and_exit(e, 'transformers', message=message) + + def check_dtype(self, config): + """check dtype.""" + logger = self.get_logger() + model_path = self.model_path + device_type = self.device_type + dtype = self.dtype + try: + import torch + + from lmdeploy.pytorch.config import ModelConfig + from lmdeploy.utils import is_bf16_supported + model_config = ModelConfig.from_hf_config(config, + model_path=model_path, + dtype=dtype) + if model_config.dtype == torch.bfloat16: + if not is_bf16_supported(device_type): + logger.warning('Device does not support bfloat16.') + except Exception as e: + message = (f'Checking failed with error {e}', + 'Please send issue to LMDeploy with error logs.') + self.log_and_exit(e, 'Model', message=message) + + def check(self): + """check.""" + import transformers + trans_version = version.parse(transformers.__version__) + + # config + config = self.check_config(trans_version) + + # transformers version + self.check_trans_version(config, trans_version) + + # dtype check + self.check_dtype(config) diff --git a/lmdeploy/pytorch/check_env/torch.py b/lmdeploy/pytorch/check_env/torch.py new file mode 100644 index 0000000000..14b24e04a0 --- /dev/null +++ 
b/lmdeploy/pytorch/check_env/torch.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseChecker + + +class TorchChecker(BaseChecker): + """check pytorch is available.""" + + def __init__(self, device: str = 'cuda', logger=None) -> None: + super().__init__(logger=logger) + self.device = device + + def check(self): + """check.""" + try: + import torch + a = torch.tensor([1, 2], device=self.device) + b = a.new_tensor([3, 4], device=self.device) + c = a + b + torch.testing.assert_close(c, a.new_tensor([4, 6])) + except Exception as e: + self.log_and_exit(e, 'PyTorch', 'PyTorch is not available.') diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py new file mode 100644 index 0000000000..9d97cd6dca --- /dev/null +++ b/lmdeploy/pytorch/check_env/transformers.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from packaging import version + +from .base import BaseChecker + +MIN_TRANSFORMERS_VERSION = '4.33.0' +MAX_TRANSFORMERS_VERSION = '4.46.1' + + +class TransformersChecker(BaseChecker): + """check transformers is available.""" + + def check(self): + """check.""" + import transformers + logger = self.get_logger() + try: + trans_version = version.parse(transformers.__version__) + min_version = version.parse(MIN_TRANSFORMERS_VERSION) + max_version = version.parse(MAX_TRANSFORMERS_VERSION) + if trans_version < min_version or trans_version > max_version: + logger.warning('LMDeploy requires transformers version: ' + f'[{MIN_TRANSFORMERS_VERSION} ~ ' + f'{MAX_TRANSFORMERS_VERSION}], ' + 'but found version: ' + f'{transformers.__version__}') + except Exception as e: + self.log_and_exit(e, 'transformers', + 'transformers is not available.') diff --git a/lmdeploy/pytorch/check_env/triton.py b/lmdeploy/pytorch/check_env/triton.py new file mode 100644 index 0000000000..4cc58c5492 --- /dev/null +++ b/lmdeploy/pytorch/check_env/triton.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from packaging import version + +from .base import BaseChecker + +MAX_TRITON_VERSION = '3.1.0' +MIN_TRITON_VERSION = '3.0.0' + + +class TritonChecker(BaseChecker): + """check triton is available.""" + + def check_version(self): + """check version.""" + logger = self.get_logger() + + # version check + import triton + max_version = version.parse(MAX_TRITON_VERSION) + min_version = version.parse(MIN_TRITON_VERSION) + triton_version = version.parse(triton.__version__) + + if triton_version > max_version: + logger.warning('PytorchEngine has not been tested on ' + f'triton>{MAX_TRITON_VERSION}.') + if triton_version < min_version: + msg = (f'triton>={MIN_TRITON_VERSION} is required. ' + f'Found triton=={triton_version}') + self.log_and_exit(mod_name='Triton', message=msg) + + def check(self): + """check.""" + logger = self.get_logger() + + msg = ( + 'Please ensure that your device is functioning properly with .\n' # noqa: E501 + 'You can verify your environment by running ' + '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.') + try: + logger.debug('Checking environment.') + import torch + + from .triton_custom_add import custom_add + a = torch.tensor([1, 2], device='cuda') + b = a.new_tensor([3, 4], device='cuda') + c = custom_add(a, b) + torch.testing.assert_close(c, a + b) + except RuntimeError as e: + ptxas_error = 'device kernel image is invalid' + if len(e.args) > 0 and ptxas_error in e.args[0]: + msg = ( + 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. 
\n' # noqa: E501 + 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501 + ' or reinstall the driver.') + self.log_and_exit(e, 'Triton', msg) + except Exception as e: + self.log_and_exit(e, 'Triton', msg) + + # version check + self.check_version() diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index a9381890ee..b2cdc304b7 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -26,6 +26,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): return config torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + # deal with case when torch_dtype is not string but torch.dtype + if isinstance(torch_dtype, torch.dtype): + torch_dtype = str(torch_dtype).split('.')[1] + if torch_dtype is None: _dtype = 'float16' if dtype == 'auto' else dtype logger.warning('Model config does not have `torch_dtype`,' @@ -37,8 +41,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): # change to user specified data type if it is not 'auto' if dtype == 'auto': torch_dtype = torch_dtype if torch_dtype in [ - torch.float16, torch.bfloat16 - ] else torch.float16 + 'float16', 'bfloat16' + ] else 'float16' else: torch_dtype = dtype config.dtype = eval(f'torch.{torch_dtype}') @@ -77,6 +81,7 @@ class CacheConfig: max_prefill_token_num: int = 4096 enable_prefix_caching: bool = False quant_policy: Literal[0, 4, 8] = 0 + device_type: str = 'cuda' def __post_init__(self): """post init.""" @@ -103,7 +108,6 @@ class ModelConfig: v_head_dim: int = None sliding_window: int = -1 dtype: torch.dtype = torch.float16 - multi_query_attention: bool = False vocab_size: int = 40000 hf_config: Any = None cogvlm_style: bool = False @@ -120,7 +124,8 @@ def get_head_size(self): def from_pretrained(cls, pretrained_model_name_or_path: str, trust_remote_code: bool = True, - dtype: str = 'auto'): + dtype: str = 'auto', + tp: int = 1): """Instantiate one of the configuration classes of the library from a pretrained model configuration. 
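A short sketch of what the config changes above mean in practice: a torch_dtype stored on the hf config as a torch.dtype object is now normalized instead of silently falling back, and tp is threaded through to the model-config builders so head counts can be validated and KV heads replicated. The model path is a placeholder:

import torch
from transformers import AutoConfig

from lmdeploy.pytorch.config import ModelConfig

hf_config = AutoConfig.from_pretrained('/path/to/model', trust_remote_code=True)
hf_config.torch_dtype = torch.bfloat16  # a torch.dtype object, not the string 'bfloat16'

model_config = ModelConfig.from_hf_config(hf_config, '/path/to/model', dtype='auto', tp=2)
# with dtype='auto', 'bfloat16'/'float16' are kept and anything else falls back to float16;
# tp must divide num_attention_heads, and the kv-head count is checked against tp as well
print(model_config.dtype)  # torch.bfloat16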
@@ -150,17 +155,21 @@ def from_pretrained(cls, setattr(hf_config, 'architectures', ['MedusaModel']) return cls.from_hf_config(hf_config, pretrained_model_name_or_path, - dtype=dtype) + dtype=dtype, + tp=tp) @classmethod def from_hf_config(cls, hf_config: Any, model_path: str = None, - dtype: str = 'auto'): + dtype: str = 'auto', + tp: int = 1): """from huggingface config.""" from lmdeploy.pytorch.configurations import AutoModelConfigBuilder - model_config = AutoModelConfigBuilder.build(hf_config, model_path) + model_config = AutoModelConfigBuilder.build(hf_config, + model_path, + tp=tp) if model_config.k_head_dim is None: assert model_config.head_dim is not None @@ -169,6 +178,13 @@ def from_hf_config(cls, assert model_config.head_dim is not None model_config.v_head_dim = model_config.head_dim + # check for tp + assert model_config.num_attention_heads % tp == 0 + if model_config.num_key_value_heads >= tp: + assert model_config.num_key_value_heads % tp == 0 + else: + assert tp % model_config.num_key_value_heads == 0 + # should after setting `hf_config` and `model_arch` attributes model_config = _update_torch_dtype(model_config, dtype) diff --git a/lmdeploy/pytorch/configurations/builder.py b/lmdeploy/pytorch/configurations/builder.py index 89bf51ca46..bafa78ba02 100644 --- a/lmdeploy/pytorch/configurations/builder.py +++ b/lmdeploy/pytorch/configurations/builder.py @@ -27,7 +27,7 @@ def condition(cls, hf_config): f'`condition` of {cls.__name__} not implemented.') @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" from .default import DefaultModelConfigBuilder @@ -46,8 +46,21 @@ def build(cls, hf_config, model_path: str = None): logger.debug(f'build model config with {valid_builder.__name__}') - cfg = valid_builder.build(hf_config, model_path) + cfg = valid_builder.build(hf_config, model_path, **kwargs) if cfg.hf_config is None: cfg.hf_config = hf_config return cfg + + @classmethod + def update_num_kv_heads(cls, hf_config, tp, num_key_value_heads): + """update num kv heads.""" + # update num_kv_heads for tp mode + if tp > 1 and tp > num_key_value_heads: + assert tp % num_key_value_heads == 0 + n_replicate = tp // num_key_value_heads + hf_config.num_replicate_key_value_heads = n_replicate + num_key_value_heads = tp + + hf_config.num_key_value_heads = num_key_value_heads + return num_key_value_heads diff --git a/lmdeploy/pytorch/configurations/chatglm.py b/lmdeploy/pytorch/configurations/chatglm.py index 7911c985d5..fbf4d48281 100644 --- a/lmdeploy/pytorch/configurations/chatglm.py +++ b/lmdeploy/pytorch/configurations/chatglm.py @@ -12,16 +12,27 @@ def condition(cls, hf_config): return hf_config.model_type == 'chatglm' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = hf_config.hidden_size // hf_config.num_attention_heads bos_token_id = hf_config.bos_token_id if bos_token_id is None: bos_token_id = hf_config.pad_token_id + + if hf_config.multi_query_attention: + num_key_value_heads = hf_config.multi_query_group_num + else: + num_key_value_heads = hf_config.num_attention_heads + + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, + num_key_value_heads) + cfg = ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_layers, num_attention_heads=hf_config.num_attention_heads, - 
num_key_value_heads=hf_config.multi_query_group_num, + num_key_value_heads=num_key_value_heads, bos_token_id=bos_token_id, eos_token_id=hf_config.eos_token_id, head_dim=head_dim, diff --git a/lmdeploy/pytorch/configurations/cogvlm.py b/lmdeploy/pytorch/configurations/cogvlm.py index b24d92d794..4736dfee69 100644 --- a/lmdeploy/pytorch/configurations/cogvlm.py +++ b/lmdeploy/pytorch/configurations/cogvlm.py @@ -12,12 +12,15 @@ def condition(cls, hf_config): return model_arch == 'CogVLMForCausalLM' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" from lmdeploy.utils import is_bf16_supported - cfg = DefaultModelConfigBuilder.build(hf_config) if getattr(hf_config, 'num_multi_query_heads', None): - cfg.num_key_value_heads = hf_config.num_multi_query_heads + hf_config.num_key_value_heads = hf_config.num_multi_query_heads + else: + hf_config.num_key_value_heads = hf_config.num_attention_heads + + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) cfg.cogvlm_style = True torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16' hf_config.torch_dtype = torch_dtype diff --git a/lmdeploy/pytorch/configurations/dbrx.py b/lmdeploy/pytorch/configurations/dbrx.py index 2c8128a5a6..dcc1222b0d 100644 --- a/lmdeploy/pytorch/configurations/dbrx.py +++ b/lmdeploy/pytorch/configurations/dbrx.py @@ -12,7 +12,7 @@ def condition(cls, hf_config): return hf_config.model_type == 'dbrx' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" hidden_size = hf_config.d_model num_heads = hf_config.n_heads diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py index 37aa4b0d69..bf06ff0c33 100644 --- a/lmdeploy/pytorch/configurations/deepseek_v2.py +++ b/lmdeploy/pytorch/configurations/deepseek_v2.py @@ -9,16 +9,22 @@ class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder): @classmethod def condition(cls, hf_config): """config.""" - return hf_config.model_type == 'deepseek_v2' + return hf_config.model_type in ['deepseek_v3', 'deepseek_v2'] @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim) k_head_dim = head_dim v_head_dim = 0 num_attention_heads = hf_config.num_attention_heads + # multi query attn num_key_value_heads = 1 + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, + num_key_value_heads) + return ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_hidden_layers, num_attention_heads=num_attention_heads, @@ -28,5 +34,4 @@ def build(cls, hf_config, model_path: str = None): head_dim=head_dim, k_head_dim=k_head_dim, v_head_dim=v_head_dim, - vocab_size=hf_config.vocab_size, - multi_query_attention=True) + vocab_size=hf_config.vocab_size) diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py index 1f84b810ea..d1337a241e 100644 --- a/lmdeploy/pytorch/configurations/default.py +++ b/lmdeploy/pytorch/configurations/default.py @@ -12,7 +12,7 @@ def condition(cls, hf_config): return True @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = hf_config.hidden_size // 
hf_config.num_attention_heads num_attention_heads = hf_config.num_attention_heads @@ -23,6 +23,11 @@ def build(cls, hf_config, model_path: str = None): if use_sliding_window: sliding_window = getattr(hf_config, 'sliding_window', sliding_window) or -1 + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, + num_key_value_heads) + return ModelConfig( hidden_size=hf_config.hidden_size, num_layers=hf_config.num_hidden_layers, diff --git a/lmdeploy/pytorch/configurations/falcon.py b/lmdeploy/pytorch/configurations/falcon.py index db4d00e397..a4c8d4d44f 100644 --- a/lmdeploy/pytorch/configurations/falcon.py +++ b/lmdeploy/pytorch/configurations/falcon.py @@ -12,7 +12,7 @@ def condition(cls, hf_config): return hf_config.model_type == 'falcon' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build falcon.""" num_attention_heads = hf_config.num_attention_heads if hf_config.new_decoder_architecture: @@ -24,6 +24,12 @@ def build(cls, hf_config, model_path: str = None): else: # rw-1b, MHA kv_head = num_attention_heads + + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + kv_head = cls.update_num_kv_heads(hf_config, tp, kv_head) + hf_config.num_kv_heads = kv_head + head_dim = hf_config.hidden_size // num_attention_heads return ModelConfig( hidden_size=hf_config.hidden_size, @@ -33,6 +39,5 @@ def build(cls, hf_config, model_path: str = None): bos_token_id=hf_config.bos_token_id, eos_token_id=hf_config.eos_token_id, head_dim=head_dim, - multi_query_attention=hf_config.multi_query, vocab_size=hf_config.vocab_size, ) diff --git a/lmdeploy/pytorch/configurations/gemma.py b/lmdeploy/pytorch/configurations/gemma.py index 338eaee6d0..d49fdbd96c 100644 --- a/lmdeploy/pytorch/configurations/gemma.py +++ b/lmdeploy/pytorch/configurations/gemma.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
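The builders above all route tensor parallelism through the new update_num_kv_heads helper: when tp exceeds the model's KV-head count, KV heads are replicated so every rank still owns at least one, and the replication factor is recorded on the HF config for the weight loader. A self-contained restatement of that arithmetic (using a SimpleNamespace in place of a real HF config) behaves like this:

from types import SimpleNamespace


def update_num_kv_heads(hf_config, tp: int, num_key_value_heads: int) -> int:
    """Replicate KV heads when tp > num_kv_heads so each rank owns >= 1 head."""
    if tp > 1 and tp > num_key_value_heads:
        assert tp % num_key_value_heads == 0
        n_replicate = tp // num_key_value_heads
        # remember the replication factor so weight loading can expand KV heads
        hf_config.num_replicate_key_value_heads = n_replicate
        num_key_value_heads = tp
    hf_config.num_key_value_heads = num_key_value_heads
    return num_key_value_heads


# e.g. an MQA model (1 KV head) served with tp=4: every rank gets one replica
cfg = SimpleNamespace()
assert update_num_kv_heads(cfg, tp=4, num_key_value_heads=1) == 4
assert cfg.num_replicate_key_value_heads == 4

# a GQA model with 8 KV heads and tp=2 is left untouched
assert update_num_kv_heads(SimpleNamespace(), tp=2, num_key_value_heads=8) == 8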
-from lmdeploy.pytorch.config import ModelConfig - from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder class GemmaModelConfigBuilder(AutoModelConfigBuilder): @@ -12,13 +11,8 @@ def condition(cls, hf_config): return hf_config.model_type in ['gemma', 'gemma2'] @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build gemma.""" - return ModelConfig(hidden_size=hf_config.hidden_size, - num_layers=hf_config.num_hidden_layers, - num_attention_heads=hf_config.num_attention_heads, - num_key_value_heads=hf_config.num_key_value_heads, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - head_dim=hf_config.head_dim, - vocab_size=hf_config.vocab_size) + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) + cfg.head_dim = hf_config.head_dim + return cfg diff --git a/lmdeploy/pytorch/configurations/internvl.py b/lmdeploy/pytorch/configurations/internvl.py index 76b4187c5f..ffff0a0e15 100644 --- a/lmdeploy/pytorch/configurations/internvl.py +++ b/lmdeploy/pytorch/configurations/internvl.py @@ -11,8 +11,9 @@ def condition(cls, hf_config): return hf_config.architectures[0] == 'InternVLChatModel' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build llava hf.""" - cfg = DefaultModelConfigBuilder.build(hf_config.llm_config) + cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path, + **kwargs) cfg.hf_config = hf_config return cfg diff --git a/lmdeploy/pytorch/configurations/llava.py b/lmdeploy/pytorch/configurations/llava.py deleted file mode 100644 index aaeeeeadfe..0000000000 --- a/lmdeploy/pytorch/configurations/llava.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
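The gemma builder now delegates to DefaultModelConfigBuilder and only patches the fields that differ, rather than assembling a ModelConfig by hand; minicpm3 below gets the same treatment. A hypothetical builder for some other architecture would follow the same delegate-then-override shape (the class name and model_type below are illustrative, not part of lmdeploy):

from lmdeploy.pytorch.configurations.builder import AutoModelConfigBuilder
from lmdeploy.pytorch.configurations.default import DefaultModelConfigBuilder


class MyModelConfigBuilder(AutoModelConfigBuilder):
    """Hypothetical builder following the delegate-then-patch pattern."""

    @classmethod
    def condition(cls, hf_config):
        """Match on the HF model_type."""
        return getattr(hf_config, 'model_type', None) == 'my_model'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Reuse the default builder (which already handles the tp/kv-head
        update), then override only the architecture-specific fields."""
        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
        cfg.head_dim = hf_config.head_dim  # e.g. explicit head_dim like gemma
        return cfg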
-from .builder import AutoModelConfigBuilder -from .default import DefaultModelConfigBuilder - - -class LlavaModelConfigBuilder(AutoModelConfigBuilder): - - @classmethod - def condition(cls, hf_config): - """config.""" - return hf_config.architectures[0] in [ - 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM' - ] - - @classmethod - def build(cls, hf_config, model_path: str = None): - """build.""" - arch = hf_config.architectures[0] - if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']: - from llava.model.language_model.llava_llama import LlavaConfig - - # reload hf_config due to model_type='llava' is already - # registered in transformers - hf_config = LlavaConfig.from_pretrained(model_path) - cfg = DefaultModelConfigBuilder.build(hf_config) - return cfg diff --git a/lmdeploy/pytorch/configurations/llava_hf.py b/lmdeploy/pytorch/configurations/llava_hf.py index 4cc007e313..5334eaec25 100644 --- a/lmdeploy/pytorch/configurations/llava_hf.py +++ b/lmdeploy/pytorch/configurations/llava_hf.py @@ -15,7 +15,7 @@ def condition(cls, hf_config): ] @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build llava hf.""" text_config = hf_config.text_config hidden_size = getattr(text_config, 'hidden_size', 4096) diff --git a/lmdeploy/pytorch/configurations/medusa.py b/lmdeploy/pytorch/configurations/medusa.py index 4935bc0e25..a4f705cd3f 100644 --- a/lmdeploy/pytorch/configurations/medusa.py +++ b/lmdeploy/pytorch/configurations/medusa.py @@ -12,7 +12,7 @@ def condition(cls, hf_config): return hf_config.architectures[0] == 'MedusaModel' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" from transformers import AutoConfig base_config = AutoConfig.from_pretrained( diff --git a/lmdeploy/pytorch/configurations/minicpm3.py b/lmdeploy/pytorch/configurations/minicpm3.py index 7cde51bd42..857673aab3 100644 --- a/lmdeploy/pytorch/configurations/minicpm3.py +++ b/lmdeploy/pytorch/configurations/minicpm3.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from lmdeploy.pytorch.config import ModelConfig from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder class MiniCPM3ModelConfigBuilder(AutoModelConfigBuilder): @@ -12,21 +12,13 @@ def condition(cls, hf_config): return hf_config.architectures[0] in ['MiniCPM3ForCausalLM'] @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim) - k_head_dim = head_dim - v_head_dim = head_dim - num_attention_heads = hf_config.num_attention_heads - num_key_value_heads = hf_config.num_key_value_heads - return ModelConfig(hidden_size=hf_config.hidden_size, - num_layers=hf_config.num_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - head_dim=head_dim, - k_head_dim=k_head_dim, - v_head_dim=v_head_dim, - vocab_size=hf_config.vocab_size, - multi_query_attention=False) + + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) + cfg.head_dim = head_dim + cfg.k_head_dim = head_dim + cfg.v_head_dim = head_dim + + return cfg diff --git a/lmdeploy/pytorch/configurations/mllama.py b/lmdeploy/pytorch/configurations/mllama.py index 2383c92c50..e56e0fbed4 100644 --- a/lmdeploy/pytorch/configurations/mllama.py +++ b/lmdeploy/pytorch/configurations/mllama.py @@ -11,8 +11,9 @@ def condition(cls, hf_config): return hf_config.architectures[0] == 'MllamaForConditionalGeneration' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build llava hf.""" - cfg = DefaultModelConfigBuilder.build(hf_config.text_config) + cfg = DefaultModelConfigBuilder.build(hf_config.text_config, + model_path, **kwargs) cfg.hf_config = hf_config return cfg diff --git a/lmdeploy/pytorch/configurations/qwen.py b/lmdeploy/pytorch/configurations/qwen.py index 05ac77c1d1..eda726de43 100644 --- a/lmdeploy/pytorch/configurations/qwen.py +++ b/lmdeploy/pytorch/configurations/qwen.py @@ -11,10 +11,10 @@ def condition(cls, hf_config): return hf_config.model_type == 'qwen' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" from lmdeploy.utils import is_bf16_supported - cfg = DefaultModelConfigBuilder.build(hf_config) + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) if cfg.bos_token_id is None: cfg.bos_token_id = 151644 if cfg.eos_token_id is None: diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index e393adeed3..e3f97cfe46 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -44,7 +44,13 @@ def __init__( self.num_layers = model_config.num_layers self.kv_cache_dtype = model_config.dtype if cache_config.quant_policy > 0: - self.kv_cache_dtype = torch.uint8 + if self.cache_config.device_type in ['cuda']: + self.kv_cache_dtype = torch.uint8 + elif self.cache_config.device_type in ['ascend', 'npu']: + self.kv_cache_dtype = torch.int8 + else: + raise ValueError( + f'unsupported device_type {self.cache_config.device_type}') # Initialize the cache. 
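With quant_policy > 0, the KV-cache storage dtype now depends on the backend: torch.uint8 on CUDA, torch.int8 on Ascend/NPU, and anything else is rejected. A standalone sketch of that selection (the real logic sits inline in CacheEngine.__init__):

import torch


def select_kv_cache_dtype(model_dtype: torch.dtype, quant_policy: int,
                          device_type: str) -> torch.dtype:
    """Pick the kv-cache storage dtype, mirroring CacheEngine.__init__."""
    if quant_policy <= 0:
        return model_dtype          # unquantized cache keeps the model dtype
    if device_type in ['cuda']:
        return torch.uint8          # quantized cache packed as uint8 on CUDA
    if device_type in ['ascend', 'npu']:
        return torch.int8           # Ascend kernels expect signed int8
    raise ValueError(f'unsupported device_type {device_type}')


assert select_kv_cache_dtype(torch.float16, 0, 'cuda') is torch.float16
assert select_kv_cache_dtype(torch.float16, 4, 'ascend') is torch.int8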
self.local_gpu_cache = self.allocate_gpu_cache() @@ -92,7 +98,7 @@ def _get_key_block_shape_impl(cls, attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads - if local and not model_config.multi_query_attention: + if local: assert num_heads % world_size == 0, \ f'num_heads: {num_heads}, world_size: {world_size}' num_heads = num_heads // world_size @@ -115,7 +121,7 @@ def _get_value_block_shape_impl(cls, attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads - if local and not model_config.multi_query_attention: + if local: assert num_heads % world_size == 0, \ f'num_heads: {num_heads}, world_size: {world_size}' num_heads = num_heads // world_size @@ -202,7 +208,7 @@ def allocate_gpu_cache(self): def allocate_cpu_cache(self): """allocate caches on Host.""" - caches = self._allocate_cache(self.num_gpu_blocks, 'cpu') + caches = self._allocate_cache(self.num_cpu_blocks, 'cpu') self.full_cpu_cache = caches self.local_cpu_cache = list(zip(*caches)) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 263ef784ee..19c8017b7a 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -8,19 +8,17 @@ import numpy as np import torch -from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, - ResponseType) +from lmdeploy.messages import PytorchEngineConfig, ResponseType from lmdeploy.utils import (get_logger, get_max_batch_size, get_model, logging_timer) from ..adapter.adapter import AdapterManager -from ..check_env import check_adapters, check_env, check_model from ..config import BackendConfig, CacheConfig, SchedulerConfig from ..devices import DeviceContext, get_device_manager -from ..messages import (InputEmbeddingRangeType, InputEmbeddingType, - MessageStatus, SchedulerSequence) -from ..model_inputs import ModelInputs, MRopeModelInputs, VisionModelInputs +from ..messages import MessageStatus, SchedulerSequence +from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler +from .engine_checker import EngineChecker from .logits_process import FusedLogitsProcessor, SamplingInputs from .model_agent import build_model_agent from .request import Request, RequestManager, RequestType, Response @@ -32,24 +30,13 @@ _EMPTY_TOKEN = np.empty((0, ), dtype=np.int64) -def _raise_exception_on_finish(task: asyncio.Task) -> None: - """raise exception on finish.""" - try: - task.result() - except asyncio.CancelledError: - return - except Exception as e: - raise e - - @dataclass class InferOutput: """The output of the model inference.""" session_id: int + resp: Response token_ids: List[int] - sender_id: int - req_id: int meta: Any = None finish: bool = False logits: torch.Tensor = None @@ -78,6 +65,40 @@ def _check_finish(scheduler: Scheduler, current_iter: int): return False +def _build_scheduler_config(engine_config: PytorchEngineConfig): + """build scheduler config.""" + scheduler_config = SchedulerConfig( + max_batches=engine_config.max_batch_size, + max_session_len=engine_config.session_len, + prefill_interval=engine_config.prefill_interval) + return scheduler_config + + +def _build_cache_config(engine_config: PytorchEngineConfig): + """build cache config.""" + cache_config = CacheConfig( + max_batches=engine_config.max_batch_size, + block_size=engine_config.block_size, + num_cpu_blocks=engine_config.num_cpu_blocks, + num_gpu_blocks=engine_config.num_gpu_blocks, + 
cache_max_entry_count=engine_config.cache_max_entry_count, + max_prefill_token_num=engine_config.max_prefill_token_num, + enable_prefix_caching=engine_config.enable_prefix_caching, + quant_policy=engine_config.quant_policy, + device_type=engine_config.device_type, + ) + return cache_config + + +def _build_backend_config(engine_config: PytorchEngineConfig): + """build backend config.""" + backend_config = BackendConfig( + eager_mode=engine_config.eager_mode, + device_type=engine_config.device_type, + ) + return backend_config + + class Engine: """The inference engine of lmdeploy pytorch. @@ -97,43 +118,23 @@ def __init__(self, engine_config = PytorchEngineConfig() else: engine_config = copy.deepcopy(engine_config) - check_env(engine_config.device_type) - check_model(model_path, trust_remote_code, engine_config.dtype, - engine_config.device_type) if engine_config.max_batch_size is None: engine_config.max_batch_size = get_max_batch_size( engine_config.device_type) - adapters = engine_config.adapters - if adapters is not None: - check_adapters(list(adapters.values())) - assert engine_config.max_batch_size > 0, 'max_batch_size should be' \ - f' greater than 0, but got {engine_config.max_batch_size}' - assert engine_config.dtype in ['auto', 'float16', 'bfloat16'], \ - f'unsupported specified data type {engine_config.dtype}' + checker = EngineChecker(model_path=model_path, + engine_config=engine_config, + trust_remote_code=trust_remote_code, + logger=logger) + checker.handle() + + adapters = engine_config.adapters self.engine_config = engine_config self.tp = engine_config.tp self.device_context = DeviceContext( device_type=engine_config.device_type) - scheduler_config = SchedulerConfig( - max_batches=engine_config.max_batch_size, - max_session_len=engine_config.session_len, - prefill_interval=engine_config.prefill_interval) - - # block_size = 1 to enable unified paging - cache_config = CacheConfig( - max_batches=engine_config.max_batch_size, - block_size=engine_config.block_size, - num_cpu_blocks=engine_config.num_cpu_blocks, - num_gpu_blocks=engine_config.num_gpu_blocks, - cache_max_entry_count=engine_config.cache_max_entry_count, - max_prefill_token_num=engine_config.max_prefill_token_num, - enable_prefix_caching=engine_config.enable_prefix_caching, - quant_policy=engine_config.quant_policy, - ) - if not os.path.exists(model_path): model_path = get_model(model_path, engine_config.download_dir, engine_config.revision) @@ -142,10 +143,9 @@ def __init__(self, if adapters is not None and len(adapters) > 0: adapters = self._download_adapters(adapters, engine_config) - backend_config = BackendConfig( - eager_mode=engine_config.eager_mode, - device_type=engine_config.device_type, - ) + scheduler_config = _build_scheduler_config(engine_config) + cache_config = _build_cache_config(engine_config) + backend_config = _build_backend_config(engine_config) with get_device_manager().context(self.device_context): self.model_agent = build_model_agent( @@ -159,6 +159,8 @@ def __init__(self, dtype=engine_config.dtype, custom_module_map=engine_config.custom_module_map) + self.input_processor = self.model_agent.get_input_processor() + cache_config = self.model_agent.cache_config self.adapter_manager = self._build_adapter_manager(adapters) self.scheduler = Scheduler(scheduler_config, cache_config) @@ -174,7 +176,6 @@ def __init__(self, # create main thread self._start_loop() self._create_buffers() - self.engine_instance = self.create_instance() self._output_stream = torch.cuda.Stream() @classmethod @@ -244,7 +245,7 
@@ def _build_adapter_manager(self, adapters): def _bind_request_manager(self): """bind request manager.""" - req_manager = RequestManager(self.engine_config.thread_safe) + req_manager = RequestManager() req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session) req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session) req_manager.bind_func(RequestType.END_SESSION, self._on_end_session) @@ -256,18 +257,15 @@ def _start_loop(self): return self.req_manager.start_loop(self.async_loop) def _response(self, + resp: Response, resp_type: ResponseType, - sender_id: int, - req_id: int, data: Any = None, err_msg: str = ''): """response.""" - self.req_manager.response( - Response(type=resp_type, - sender_id=sender_id, - req_id=req_id, - data=data, - err_msg=err_msg)) + resp.type = resp_type + resp.data = data + resp.err_msg = err_msg + self.req_manager.response(resp) def _get_max_session_len(self): """get max session len.""" @@ -293,7 +291,7 @@ def _on_add_session(self, reqs: Request, **kwargs): self.scheduler.add_session(session_id) resp_type = ResponseType.SUCCESS if resp: - self._response(resp_type, req.sender_id, req.req_id) + self._response(req.resp, resp_type) def _on_stop_session(self, reqs: Request, **kwargs): """on stop session callback.""" @@ -305,7 +303,7 @@ def _on_stop_session(self, reqs: Request, **kwargs): self.scheduler.stop_session(session_id) resp_type = ResponseType.SUCCESS if resp: - self._response(resp_type, req.sender_id, req.req_id) + self._response(req.resp, resp_type) def _on_end_session(self, reqs: Request, **kwargs): """on end session callback.""" @@ -317,10 +315,35 @@ def _on_end_session(self, reqs: Request, **kwargs): self.scheduler.end_session(session_id) resp_type = ResponseType.SUCCESS if resp: - self._response(resp_type, req.sender_id, req.req_id) + self._response(req.resp, resp_type) def _on_add_message(self, reqs: Request, **kwargs): """on add message callback.""" + for req in reqs: + req_data = req.data + if req_data.get('input_multimodals', None) is None: + continue + elif self.input_processor is None: + logger.warning('Do not support Multimodal inputs.') + continue + input_ids = req_data['token_ids'] + input_multimodals = req_data['input_multimodals'] + if len(input_multimodals) == 0: + req_data['input_multimodals'] = None + continue + result = self.input_processor.preprocess_input( + input_ids, input_multimodals) + + input_ids = result.input_ids + input_multimodals = result.input_multimodals + + req_data['token_ids'] = input_ids + req_data['input_multimodals'] = input_multimodals + + if len(reqs) > 0: + self._add_message(reqs) + + def _add_message(self, reqs): def __update_bad_words(msg): """update bad words.""" @@ -346,8 +369,7 @@ def __update_max_new_tokens(msg): for req in reqs: session_id = req.data['session_id'] if session_id not in self.scheduler.sessions: - self._response(ResponseType.SESSION_NOT_EXIST, req.sender_id, - req.req_id) + self._response(req.resp, ResponseType.SESSION_NOT_EXIST) continue session_id = req.data['session_id'] sess = self.scheduler.sessions[session_id] @@ -360,11 +382,8 @@ def __update_max_new_tokens(msg): sampling_param=req.data['sampling_param'], adapter_name=req.data['adapter_name'], return_logits=req.data.get('return_logits', False), + multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings'), - mrope_position_ids=req.data.get('mrope_position_ids'), - mrope_position_delta=req.data.get('mrope_position_delta'), - cross_attention_states=req.data.get( - 
'cross_attention_states'), ) msg = next(iter(sess.sequences.values())) __update_bad_words(msg) @@ -372,9 +391,11 @@ def __update_max_new_tokens(msg): self.scheduler.add_sequence(msg) else: msg = next(iter(sess.sequences.values())) - msg.update_token_ids(req.data['token_ids'], - req.data.get('input_embeddings'), - req.data.get('cross_attention_states')) + msg.update_token_ids( + req.data['token_ids'], + multimodals=req.data.get('input_multimodals'), + embeddings=req.data.get('input_embeddings'), + ) msg.num_new_tokens = 0 msg.sampling_param = req.data['sampling_param'] msg.return_logits = req.data.get('return_logits', False) @@ -382,8 +403,7 @@ def __update_max_new_tokens(msg): __update_bad_words(msg) __update_max_new_tokens(msg) - msg.sender_id = req.sender_id - msg.req_id = req.req_id + msg.resp = req.resp @property def model_config(self): @@ -420,7 +440,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): seq_length = self._seq_length_buf[:batch_size] max_q_seq_length = seq_length.max().item() - # TODO: get block offsets is slow when block_size = 1 block_offsets = self.scheduler.get_block_tables(messages) block_offsets = _tensorlize_block_offsets(block_offsets) @@ -438,13 +457,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): num_ignored_history = [msg.num_ignored_history for msg in messages] num_ignored_history = torch.tensor(num_ignored_history) - def __get_cogvlm_image_info(): - """Get cogvlm history image info for position ids.""" - history_image_nums = torch.LongTensor( - [msg.history_image_num for msg in messages]) - history_image_token_lengths = torch.LongTensor( - [msg.history_image_token_len for msg in messages]) - return history_image_nums, history_image_token_lengths + model_metas = [msg.model_meta for msg in messages] def __get_vlm_embeddings(): """get vlm input embeddings and indexings.""" @@ -469,25 +482,9 @@ def __get_vlm_embeddings(): return (input_embeddings, input_embedding_indexing, input_embedding_ranges) - def __get_mrope_inputs(): - """get multimodal rotary position inputs.""" - position_ids = [msg.mrope_position_ids for msg in messages] - deltas = [msg.mrope_position_delta for msg in messages] - return MRopeModelInputs(position_ids=position_ids, deltas=deltas) - # for inputs with embeddings history_image_nums = None history_image_token_lengths = None - # only for cogvlm - if self.model_config.cogvlm_style: - (history_image_nums, - history_image_token_lengths) = __get_cogvlm_image_info() - # only for qwen2_vl - mrope_inputs = None - has_mrope_params = any( - [msg.mrope_position_ids is not None for msg in messages]) - if has_mrope_params: - mrope_inputs = __get_mrope_inputs() input_embeddings = None input_embedding_indexing = None @@ -498,25 +495,40 @@ def __get_mrope_inputs(): (input_embeddings, input_embedding_indexing, input_embedding_ranges) = __get_vlm_embeddings() + input_multimodals = None + has_multimodal = any( + [not msg.history_multimodals.empty() for msg in messages]) + if has_multimodal: + has_multimodal = False + input_multimodals = [ + msg.get_input_multimodals() for msg in messages + ] + for input_mm in input_multimodals: + for val in input_mm.values(): + if len(val) > 0: + has_multimodal = True + break + if has_multimodal: + break + vision_embedding_inputs = None - if has_embedding or history_image_nums is not None: + if has_embedding or has_multimodal or history_image_nums is not None: vision_embedding_inputs = VisionModelInputs( history_lengths=history_lengths, history_image_nums=history_image_nums, 
history_image_token_lengths=history_image_token_lengths, input_embeddings=input_embeddings, input_embedding_indexing=input_embedding_indexing, - input_embedding_ranges=input_embedding_ranges) - - # only for mllama - cross_attention_states = None - history_cross_kv_seqlens = None - if any([msg.cross_attention_states is not None for msg in messages]): - cross_attention_states = [ - msg.cross_attention_states for msg in messages - ] - history_cross_kv_seqlens = torch.tensor( - [msg.history_cross_kv_seqlens for msg in messages]) + input_embedding_ranges=input_embedding_ranges, + input_multimodals=input_multimodals) + + # cross + cross_length = torch.tensor([msg.num_cross for msg in messages]) + history_cross_length = torch.tensor( + [msg.num_history_cross for msg in messages]) + if (cross_length + history_cross_length).max().item() == 0: + cross_length = None + history_cross_length = None return ModelInputs( input_ids=input_ids, @@ -527,9 +539,9 @@ def __get_mrope_inputs(): num_ignored_history=num_ignored_history, local_adapter_ids=local_adapter_ids, vision_inputs=vision_embedding_inputs, - mrope_inputs=mrope_inputs, - cross_attention_states=cross_attention_states, - history_cross_kv_seqlens=history_cross_kv_seqlens, + cross_length=cross_length, + history_cross_length=history_cross_length, + model_metas=model_metas, ) def _batch_stopping_criteria(self, token_ids: torch.Tensor, @@ -550,11 +562,12 @@ def _batch_stopping_criteria(self, token_ids: torch.Tensor, return stopped, num_appendable_ids @logging_timer('SamplingLogits', logger) - def async_sampling_logits(self, logits: torch.Tensor, - all_ids: torch.Tensor, - guided_input_ids: torch.Tensor, - sampling_inputs: SamplingInputs, - inputs: ModelInputs, ignore_eos: torch.Tensor): + async def async_sampling_logits(self, logits: torch.Tensor, + all_ids: torch.Tensor, + guided_input_ids: torch.Tensor, + sampling_inputs: SamplingInputs, + inputs: ModelInputs, + ignore_eos: torch.Tensor): """sampling logits.""" def __get_last_logits(): @@ -569,7 +582,8 @@ def __get_last_logits(): split_logits = __get_last_logits() logits_processor = FusedLogitsProcessor(sampling_inputs, ignore_eos, self.tokenizer.model.model) - logits = logits_processor(all_ids, guided_input_ids, split_logits) + logits = await logits_processor(all_ids, guided_input_ids, + split_logits) next_token_ids = logits_processor.sampling(logits) return next_token_ids @@ -587,20 +601,24 @@ def extract_tokens(self, token_ids, eos_token_ids): @logging_timer('UpdateRunning', logger) def update_running(self, running: SeqList, next_token_ids: torch.Tensor, - stopped: torch.Tensor): + stopped: torch.Tensor, model_metas: List[Dict[str, + Any]]): """update scheduler.""" + if model_metas is None: + model_metas = [None] * len(running) next_token_ids = next_token_ids.numpy() - eos_token_id = self.model_config.eos_token_id - for token, msg, stop in zip(next_token_ids, running, stopped): + for token, msg, stop, model_meta in zip(next_token_ids, running, + stopped, model_metas): if msg.status != MessageStatus.RUNNING: continue + eos_token_id = self.model_config.eos_token_id update_token, eos_stop = self.extract_tokens(token, eos_token_id) stop = stop or eos_stop if stop: update_token = _EMPTY_TOKEN else: msg.num_new_tokens += len(update_token) - msg.update_token_ids(update_token) + msg.update_token_ids(update_token, model_meta=model_meta) if stop: msg.status = MessageStatus.STOPPED @@ -666,12 +684,14 @@ async def __long_context_single_forward(inputs): batch_size = seq_len.size(0) assert batch_size == 1 - 
new_inputs = inputs.split(max_prefill_token_num, - self.cache_config.block_size) + new_inputs = inputs.split(max_prefill_token_num) + model_metas = new_inputs[0].model_metas output_gather = _OutputGather(max_seq_len) for inp in new_inputs: + inp.model_metas = model_metas tmp_out = await __forward(inp) + model_metas = tmp_out.get('model_metas') output_gather.gather(tmp_out) tmp_out.pop('hidden_states', None) tmp_out['hidden_states'] = output_gather.get_output() @@ -703,33 +723,12 @@ async def __long_context_single_forward(inputs): ret['logits'] = logits return ret - def _make_infer_outputs(self, next_token_ids: torch.LongTensor, - logits: torch.Tensor, stopped: torch.Tensor, - event: torch.cuda.Event): + async def _make_infer_outputs(self, next_token_ids: torch.LongTensor, + logits: torch.Tensor, stopped: torch.Tensor, + model_metas: List[Dict[str, Any]], + event: torch.cuda.Event): """make infer output.""" - def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, - stopped: bool): - """check if output is necessary.""" - if isinstance(token, list): - idx = len(token) - for i, t in enumerate(token): - if t == -1: - idx = i - break - if stopped: - idx = min( - idx, - msg.sampling_param.max_new_tokens - msg.num_new_tokens) - token = token[:idx] - else: - if stopped: - return [] - if token in msg.sampling_param.stop_words: - return [] - token = [token] - return token - def __get_q_start_loc(): inputs = self._inputs seq_length = inputs.seq_length @@ -739,15 +738,16 @@ def __get_q_start_loc(): else: return seq_length.cumsum(0) - seq_length + while not event.query(): + await asyncio.sleep(0.001) with torch.cuda.stream(self._output_stream): - event.wait() next_token_ids = next_token_ids.cpu() stopped = stopped.cpu() running = self._running is_run = [seq.status == MessageStatus.RUNNING for seq in running] stopped = stopped.tolist() - self.update_running(running, next_token_ids, stopped) + self.update_running(running, next_token_ids, stopped, model_metas) # generate output next_token_ids = next_token_ids.tolist() @@ -756,16 +756,15 @@ def __get_q_start_loc(): for idx, msg in enumerate(running): if not is_run[idx]: continue - token_ids = __get_out_token_ids(next_token_ids[idx], msg, - stopped[idx]) + token_ids = msg.all_ids[-msg.num_new_tokens:] finish = msg.status == MessageStatus.STOPPED if not finish and len(token_ids) == 0: continue session_id = msg.session_id + resp = msg.resp out = InferOutput( session_id=session_id, - sender_id=msg.sender_id, - req_id=msg.req_id, + resp=resp, finish=finish, token_ids=token_ids, ) @@ -805,8 +804,7 @@ def __update_inputs(next_token_ids): logger.debug(': ' f'batch_size={inputs.seq_length.size(0)} ' f'num_tokens={inputs.input_ids.size(-1)}') - if self.gpu_count == 1: - inputs = inputs.to_device('cuda') + inputs = inputs.to_device('cuda') is_decoding = inputs.is_decoding if all_ids is not None: all_ids = all_ids.cuda() @@ -827,7 +825,7 @@ def __update_inputs(next_token_ids): logits = logits[0] # [bs, seq, prob] -> [seq, prob] # sampling - next_token_ids = self.async_sampling_logits( + next_token_ids = await self.async_sampling_logits( logits, all_ids, guided_input_ids, sampling_inputs, inputs, num_ignore_eos > 0) num_ignore_eos = num_ignore_eos - 1 @@ -873,13 +871,16 @@ def __update_inputs(next_token_ids): next_token_ids, sampling_inputs.stop_words, num_appendable_ids) # send output + model_metas = output.get('model_metas') finish = (idx == loop_count - 1) finish = finish or _check_finish(self.scheduler, idx) event = torch.cuda.Event() event.record() - 
output = (next_token_ids, logits, stopped, event) + output = (next_token_ids, logits, stopped, model_metas, event) output_que.put_nowait((finish, output)) + inputs.model_metas = model_metas + if finish: break @@ -889,9 +890,28 @@ def __update_inputs(next_token_ids): swap_out_map = dict() __update_inputs(next_token_ids) + def _set_has_runable_event(self, has_runable_event: asyncio.Event): + """set has runable event.""" + if self.scheduler.has_unfinished(): + has_runable_event.set() + else: + has_runable_event.clear() + + @torch.inference_mode() + async def _async_loop_preprocess_message(self, + forward_event: asyncio.Event, + has_runable_event: asyncio.Event): + """preprocess msg.""" + while True: + if self.scheduler.has_unfinished(): + await forward_event.wait() + await self.req_manager.step() + self._set_has_runable_event(has_runable_event) + @torch.inference_mode() async def _async_loop_background(self, in_que: asyncio.Queue, - out_que: asyncio.Queue): + out_que: asyncio.Queue, + forward_event: asyncio.Event): """async loop background.""" def __gather_all_ids(seqs: SeqList, sampling_inputs: SamplingInputs): @@ -952,66 +972,52 @@ def __need_logits(seqs: SeqList): while True: is_prefill, scheduler_output = await in_que.get() - try: - running = scheduler_output.running - swap_in_map = scheduler_output.swap_in_map - swap_out_map = scheduler_output.swap_out_map - prefill_interval = self.scheduler_config.prefill_interval - loop_count = 1 if is_prefill else (prefill_interval - 1) - assert len(running) > 0 - - # create inputs - inputs = self.create_model_inputs(running, is_prefill) - sampling_inputs = SamplingInputs.from_sampling_params(running) - all_ids = __gather_all_ids(running, sampling_inputs) - guided_input_ids = __gather_guided_input_ids( - running, sampling_inputs) - num_appendable_ids = __get_num_appendable_ids(running) - num_ignore_eos = __get_num_ignore_eos(running) - return_logits = __need_logits(running) - - self._running = running - self._inputs = inputs - - await self._async_step_background( - inputs=inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map, - all_ids=all_ids, - guided_input_ids=guided_input_ids, - sampling_inputs=sampling_inputs, - num_appendable_ids=num_appendable_ids, - num_ignore_eos=num_ignore_eos, - loop_count=loop_count, - return_logits=return_logits, - output_que=out_que, - ) - except Exception as e: - out_que.put_nowait((True, e)) - finally: - in_que.task_done() - - @torch.inference_mode() - async def _async_loop(self): - """Main loop of the engine. 
+ running = scheduler_output.running + swap_in_map = scheduler_output.swap_in_map + swap_out_map = scheduler_output.swap_out_map + prefill_interval = self.scheduler_config.prefill_interval + loop_count = 1 if is_prefill else (prefill_interval - 1) + assert len(running) > 0 + + # create inputs + inputs = self.create_model_inputs(running, is_prefill) + sampling_inputs = SamplingInputs.from_sampling_params(running) + all_ids = __gather_all_ids(running, sampling_inputs) + guided_input_ids = __gather_guided_input_ids( + running, sampling_inputs) + num_appendable_ids = __get_num_appendable_ids(running) + num_ignore_eos = __get_num_ignore_eos(running) + return_logits = __need_logits(running) + + self._running = running + self._inputs = inputs + + forward_event.clear() + await self._async_step_background( + inputs=inputs, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map, + all_ids=all_ids, + guided_input_ids=guided_input_ids, + sampling_inputs=sampling_inputs, + num_appendable_ids=num_appendable_ids, + num_ignore_eos=num_ignore_eos, + loop_count=loop_count, + return_logits=return_logits, + output_que=out_que, + ) + forward_event.set() - Each engine instance would communicate with the engine by queue. - """ - prefill_interval = self.scheduler_config.prefill_interval - in_que = asyncio.Queue() - out_que = asyncio.Queue() - loop_background = asyncio.get_event_loop().create_task( - self._async_loop_background(in_que, out_que), - name='MainLoopBackground') - loop_background.add_done_callback(_raise_exception_on_finish) + async def _async_send_responses(self, que: asyncio.Queue, + forward_event: asyncio.Event): + """send responses.""" def __send_resp(out: InferOutput): """send response.""" resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS) - self._response(resp_type, - sender_id=out.sender_id, - req_id=out.req_id, + self._response(out.resp, + resp_type, data=dict(token_ids=out.token_ids, logits=out.logits)) @@ -1020,9 +1026,89 @@ def __send_resps(step_outputs: Dict[int, InferOutput]): for out in step_outputs.values(): __send_resp(out) + while True: + resps = await que.get() + if self.scheduler.has_unfinished(): + await forward_event.wait() + __send_resps(resps) + + @staticmethod + def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]): + """add loop tasks done callback.""" + + def __task_callback(task: asyncio.Task) -> None: + """raise exception on finish.""" + task_name = task.get_name() + try: + task.result() + except asyncio.CancelledError: + logger.debug(f'Task <{task_name}> cancelled.') + return + except Exception: + logger.exception(f'Task <{task_name}> failed') + for task in tasks: + if not task.cancelled(): + task.cancel() + + for task in tasks: + task.add_done_callback(__task_callback) + + @torch.inference_mode() + async def _async_loop(self): + """Main loop of the engine. + + Each engine instance would communicate with the engine by queue. 
+ """ + event_loop = asyncio.get_event_loop() + prefill_interval = self.scheduler_config.prefill_interval + + # forward task + in_que = asyncio.Queue() + out_que = asyncio.Queue() + forward_event = asyncio.Event() + forward_event.set() + loop_background = event_loop.create_task(self._async_loop_background( + in_que, out_que, forward_event), + name='MainLoopBackground') + + # preprocess task + has_runable_event = asyncio.Event() + loop_msg_proc = event_loop.create_task( + self._async_loop_preprocess_message(forward_event, + has_runable_event), + name='MainLoopPreprocessMessage') + + # response task + resp_que = asyncio.Queue() + loop_send_resp = event_loop.create_task(self._async_send_responses( + resp_que, forward_event), + name='MainLoopResponse') + + loop_main = asyncio.current_task() + loop_tasks: List[asyncio.Task] = [ + loop_main, loop_background, loop_msg_proc, loop_send_resp + ] + self._add_loop_tasks_done_callback(loop_tasks) + + def __do_prefill(): + # decoding if no waiting + if not self.scheduler.has_waiting(): + return False + num_running = self.scheduler.num_running() + num_waiting = self.scheduler.num_waiting() + max_batches = self.scheduler_config.max_batches + # prefill if too much waiting + if num_waiting >= 4: + return True + # prefill if no enough running + if num_running < max_batches * 0.5: + return True + # decoding + return False + async def __step(): """step decoding.""" - prefill = self.scheduler.has_waiting() + prefill = __do_prefill() schedule_output = self.scheduler.schedule( is_prefill=prefill, prealloc_size=prefill_interval) # schedule decoding if no valid prefill reqs. @@ -1036,29 +1122,13 @@ async def __step(): in_que.put_nowait((prefill, schedule_output)) finish = False while not finish: - if self.req_manager.has_requests(): - self.req_manager.step() finish, out = await out_que.get() - try: - if isinstance(out, Exception): - raise out - next_token_ids, logits, stopped, event = out - step_outputs = self._make_infer_outputs( - next_token_ids, logits, stopped, event) - __send_resps(step_outputs) - except Exception as e: - raise e - finally: - out_que.task_done() + step_outputs = await self._make_infer_outputs(*out) + self._set_has_runable_event(has_runable_event) + resp_que.put_nowait(step_outputs) while True: - if self.req_manager.has_requests(): - self.req_manager.step() - - if not self.scheduler.has_unfinished(): - await asyncio.sleep(0.01) - continue - + await has_runable_event.wait() await __step() async def async_loop(self): @@ -1077,78 +1147,3 @@ def create_instance(self, cuda_stream_id=0): """ from .engine_instance import EngineInstance return EngineInstance(self) - - async def async_batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None): - """Send inference request. - - Args: - session_ids (List[int]): The session id. - token_ids (List[int]): The input token ids. - gen_config (GenerationConfig): The sampling parameters. - adapter_names (List[str]): The name of the adapters. - keep_cache (bool): Keep kv cache after infer. - - Returns: - int: Error flags. 0 if success. - List[int]: The streaming output tokens. - int: The number of the output tokens. 
- """ - return await self.engine_instance.async_batched_infer( - session_ids=session_ids, - token_ids=token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - - def batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None): - """batched infer.""" - return self.engine_instance.batched_infer( - session_ids=session_ids, - token_ids=token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - - async def async_add_session(self, session_id: int): - """Add new session.""" - return await self.engine_instance._async_try_add_session(session_id) - - def add_session(self, session_id: int): - """Add new session.""" - return self.engine_instance._try_add_session(session_id) - - async def async_cancel(self, session_id: int): - """Stop the given session.""" - return await self.engine_instance.async_cancel(session_id) - - def cancel(self, session_id: int): - """Add new session.""" - return self.engine_instance.cancel(session_id) - - async def async_end(self, session_id: int): - """End the given session.""" - return await self.engine_instance.async_end(session_id) - - def end(self, session_id: int): - """Add new session.""" - return self.engine_instance.end(session_id) diff --git a/lmdeploy/pytorch/engine/engine_checker.py b/lmdeploy/pytorch/engine/engine_checker.py new file mode 100644 index 0000000000..5b0cc9865c --- /dev/null +++ b/lmdeploy/pytorch/engine/engine_checker.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from lmdeploy.messages import PytorchEngineConfig + +from ..check_env.adapter import AdapterChecker +from ..check_env.base import BaseChecker +from ..check_env.model import ModelChecker +from ..check_env.torch import TorchChecker +from ..check_env.transformers import TransformersChecker + + +class EngineChecker(BaseChecker): + """check transformers is available.""" + + def __init__(self, + model_path: str, + engine_config: PytorchEngineConfig, + trust_remote_code: bool = True, + logger=None): + super().__init__(logger) + logger = self.get_logger() + + self.engine_config = engine_config + + dtype = engine_config.dtype + device_type = engine_config.device_type + + # pytorch + torch_checker = TorchChecker(logger=logger) + + if device_type == 'cuda': + # triton + from ..check_env.triton import TritonChecker + triton_checker = TritonChecker(logger=logger) + triton_checker.register_required_checker(torch_checker) + self.register_required_checker(triton_checker) + else: + # deeplink + from ..check_env.deeplink import DeeplinkChecker + dl_checker = DeeplinkChecker(device_type, logger=logger) + self.register_required_checker(dl_checker) + self.register_required_checker(torch_checker) + + # transformers + + # model + trans_checker = TransformersChecker() + model_checker = ModelChecker(model_path=model_path, + trust_remote_code=trust_remote_code, + dtype=dtype, + device_type=device_type, + logger=logger) + model_checker.register_required_checker(torch_checker) + model_checker.register_required_checker(trans_checker) + self.register_required_checker(model_checker) + + # adapters + adapters = engine_config.adapters + if adapters is not None: + adapter_paths = list(adapters.values()) + for adapter in adapter_paths: + adapter_checker = AdapterChecker(adapter, logger=logger) + self.register_required_checker(adapter_checker) + + def check(self): + """check.""" + engine_config = self.engine_config + + if engine_config.thread_safe: + self.log_and_exit( + mod_name='Engine', + message='thread safe mode is no longer supported.\n' + 'Read https://github.com/InternLM/lmdeploy/blob/main/docs/en/advance/pytorch_multithread.md for more details.', # noqa: E501 + ) + + if engine_config.max_batch_size <= 0: + self.log_and_exit( + mod_name='Engine', + message='max_batch_size should be' + f' greater than 0, but got {engine_config.max_batch_size}') diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 455ab1ccb3..5cf1366783 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. 
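EngineChecker replaces the old check_env/check_model/check_adapters calls with a tree of small checkers wired together through register_required_checker and driven by handle(). The toy classes below re-implement just that composition idea to show the flow; they are not the real BaseChecker, which reports failures via log_and_exit rather than raising:

class ToyChecker:
    """Minimal stand-in for the BaseChecker composition pattern."""

    def __init__(self):
        self._required = []

    def register_required_checker(self, checker: 'ToyChecker'):
        self._required.append(checker)

    def check(self):
        """Override with the validation for one concern."""

    def handle(self):
        # required (dependency) checkers run first, then this checker itself
        for checker in self._required:
            checker.handle()
        self.check()


class BatchSizeChecker(ToyChecker):
    """Example concern: EngineChecker.check performs the same test,
    but reports the failure through log_and_exit instead of raising."""

    def __init__(self, max_batch_size: int):
        super().__init__()
        self.max_batch_size = max_batch_size

    def check(self):
        if self.max_batch_size <= 0:
            raise RuntimeError('max_batch_size should be greater than 0, '
                               f'but got {self.max_batch_size}')


engine_checker = ToyChecker()
engine_checker.register_required_checker(BatchSizeChecker(max_batch_size=128))
engine_checker.handle()  # all checks pass; a non-positive batch size would raise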
-from typing import List +from typing import Any, Dict, List from lmdeploy.messages import EngineOutput, GenerationConfig from lmdeploy.utils import get_logger -from ..messages import (InputEmbeddingRangeType, InputEmbeddings, - InputEmbeddingType, SamplingParam) +from ..messages import SamplingParam from .engine import Engine from .request import RequestSender, RequestType, Response, ResponseType logger = get_logger('lmdeploy') +InputMultiModalType = List[Dict[str, Any]] + def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None): """check if response has state.""" @@ -42,8 +43,8 @@ async def async_try_add_session(req_sender: RequestSender, session_id: int): async def async_end(req_sender: RequestSender, session_id: int): """End the given session.""" - await req_sender.async_send_async( - RequestType.END_SESSION, dict(session_id=session_id, response=False)) + req_sender.send_async(RequestType.END_SESSION, + dict(session_id=session_id, response=False)) async def async_cancel(req_sender: RequestSender, session_id: int): @@ -114,15 +115,13 @@ def _try_add_session(self, session_id: int): """ return try_add_session(self.req_sender, session_id) - async def async_stream_infer( - self, - session_id: int, - input_ids: List[int], - gen_config: GenerationConfig = None, - adapter_name: str = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, - **kwargs): + async def async_stream_infer(self, + session_id: int, + input_ids: List[int], + gen_config: GenerationConfig = None, + multimodal: InputMultiModalType = None, + adapter_name: str = None, + **kwargs): """Send stream inference request. Args: @@ -141,52 +140,37 @@ async def async_stream_infer( return gen_config = gen_config or GenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) - await self.req_sender.async_send_async( - RequestType.ADD_SESSION, dict(session_id=session_id, - response=False)) - input_embeddings_new: List[InputEmbeddings] = None - if input_embeddings is not None and len(input_embeddings) > 0: - assert len(input_embeddings) == len(input_embedding_ranges) - input_embeddings_new = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_embeddings, input_embedding_ranges) - ] - msg = dict(token_ids=input_ids, - session_id=session_id, - sampling_param=sampling_param, - adapter_name=adapter_name, - input_embeddings=input_embeddings_new, - mrope_position_ids=kwargs.get('mrope_position_ids'), - mrope_position_delta=kwargs.get('mrope_position_delta'), - cross_attention_states=kwargs.get('cross_attention_states')) - req_id = await self.req_sender.async_send_async( - RequestType.ADD_MESSAGE, msg) + self.req_sender.send_async(RequestType.ADD_SESSION, + dict(session_id=session_id, response=False)) + msg = dict( + token_ids=input_ids, + session_id=session_id, + sampling_param=sampling_param, + adapter_name=adapter_name, + input_multimodals=multimodal, + ) + resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg) - token_ids = [] while True: - resp = await self.req_sender.async_recv(req_id) + resp = await self.req_sender.async_recv(resp) - if resp.req_id != req_id: - continue if resp.type == ResponseType.SUCCESS: - token_ids += resp.data['token_ids'] + token_ids = resp.data['token_ids'].tolist() yield EngineOutput(resp.type, token_ids, len(token_ids)) elif resp.type == ResponseType.FINISH: - token_ids += resp.data['token_ids'] + token_ids = resp.data['token_ids'].tolist() yield EngineOutput(resp.type, token_ids, 
len(token_ids)) break else: yield EngineOutput(resp.type, [], 0) break - async def async_infer( - self, - session_id: int, - input_ids: List[int] = None, - gen_config: GenerationConfig = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, - **kwargs): + async def async_infer(self, + session_id: int, + input_ids: List[int] = None, + multimodal: InputMultiModalType = None, + gen_config: GenerationConfig = None, + **kwargs): """Send inference request. Args: @@ -200,13 +184,11 @@ async def async_infer( int: The number of the output tokens. """ token_ids = [] - async for outputs in self.async_stream_infer( - session_id, - input_ids, - gen_config=gen_config, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs): + async for outputs in self.async_stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + **kwargs): status, tmp_ids = outputs.status, outputs.token_ids if status not in [ResponseType.SUCCESS, ResponseType.FINISH]: return EngineOutput(status, token_ids, len(token_ids)) @@ -217,10 +199,9 @@ async def async_infer( def stream_infer(self, session_id: int, input_ids: List[int], + multimodal: InputMultiModalType = None, gen_config: GenerationConfig = None, adapter_name: str = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, **kwargs): """Send stream inference request. @@ -241,14 +222,12 @@ def stream_infer(self, def __call_async(): """call async.""" - coro_gen = self.async_stream_infer( - session_id, - input_ids, - gen_config, - adapter_name, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs) + coro_gen = self.async_stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + adapter_name=adapter_name, + **kwargs) while True: try: yield self.req_sender.run_until_complete( @@ -256,53 +235,13 @@ def __call_async(): except StopAsyncIteration: break - if not self.req_sender.is_thread_safe(): - yield from __call_async() - return - - gen_config = gen_config or GenerationConfig() - sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) - self.req_sender.send_async(RequestType.ADD_SESSION, - dict(session_id=session_id, response=False)) - input_embeddings_new: List[InputEmbeddings] = None - if input_embeddings is not None and len(input_embeddings) > 0: - assert len(input_embeddings) == len(input_embedding_ranges) - input_embeddings_new = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_embeddings, input_embedding_ranges) - ] - msg = dict( - token_ids=input_ids, - session_id=session_id, - sampling_param=sampling_param, - adapter_name=adapter_name, - input_embeddings=input_embeddings_new, - ) - req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg) - - token_ids = [] - while True: - resp = self.req_sender.recv(req_id) - - if resp.req_id != req_id: - continue - if resp.type == ResponseType.SUCCESS: - token_ids += resp.data['token_ids'] - yield EngineOutput(resp.type, token_ids, len(token_ids)) - elif resp.type == ResponseType.FINISH: - token_ids += resp.data['token_ids'] - yield EngineOutput(resp.type, token_ids, len(token_ids)) - break - else: - yield EngineOutput(resp.type, [], 0) - break + yield from __call_async() def infer(self, session_id: int, input_ids: List[int] = None, + multimodal: InputMultiModalType = None, gen_config: GenerationConfig = None, - input_embeddings: 
InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, **kwargs): """Send inference request. @@ -317,13 +256,11 @@ def infer(self, int: The number of the output tokens. """ token_ids = [] - for outputs in self.stream_infer( - session_id, - input_ids, - gen_config=gen_config, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs): + for outputs in self.stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + **kwargs): status, tmp_ids = outputs.status, outputs.token_ids if status not in [ResponseType.SUCCESS, ResponseType.FINISH]: return EngineOutput(status, token_ids, len(token_ids)) @@ -331,127 +268,6 @@ def infer(self, return EngineOutput(0, token_ids, len(token_ids)) - async def async_batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, - ): - """Send inference request. - - Args: - session_ids (List[int]): The session id. - token_ids (List[int]): The input token ids. - gen_config (GenerationConfig): The sampling parameters. - adapter_names (List[str]): The name of the adapters. - keep_cache (bool): Keep kv cache after infer. - - Returns: - int: Error flags. 0 if success. - List[int]: The streaming output tokens. - int: The number of the output tokens. - """ - batch_size = len(token_ids) - assert len(session_ids) == batch_size - if adapter_names is not None: - assert len(adapter_names) == batch_size - else: - adapter_names = [None for _ in range(batch_size)] - - if input_embeddings is not None: - assert len(input_embeddings) == batch_size - assert len(input_embedding_ranges) == batch_size - else: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - - async def _add_sessions(session_ids): - for session_id in session_ids: - await self._async_try_add_session(session_id) - - async def _add_messages(session_ids, token_ids, adapter_names, - input_embeddings, input_embedding_ranges): - add_msgs = [] - sampling_param = SamplingParam.from_gen_config(gen_config) - for session_id, token_id, adapter_name, input_emb, input_ranges in zip( # noqa: E501 - session_ids, token_ids, adapter_names, input_embeddings, - input_embedding_ranges): - cur_input_embeddings: List[InputEmbeddings] = None - if input_emb is not None and len(input_emb) > 0: - assert len(input_emb) == len(input_ranges) - cur_input_embeddings = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_emb, input_ranges) - ] - msg = dict( - token_ids=token_id, - session_id=session_id, - sampling_param=sampling_param, - adapter_name=adapter_name, - input_embeddings=cur_input_embeddings, - ) - add_msgs.append(msg) - req_types = [RequestType.ADD_MESSAGE] * batch_size - req_ids = await self.req_sender.async_batched_send_async( - req_types, data=add_msgs) - return req_ids - - await _add_sessions(session_ids) - req_ids = await _add_messages(session_ids, token_ids, adapter_names, - input_embeddings, input_embedding_ranges) - - # receive messages - req_idx_map = dict(zip(req_ids, range(len(req_ids)))) - output_token_ids = [list() for _ in req_ids] - status = 0 - finish_count = batch_size - while finish_count: - resp = await self.req_sender.async_recv_any() - if resp.req_id not in req_ids: - continue - idx = req_idx_map[resp.req_id] - token_ids = 
output_token_ids[idx] - if resp.type == ResponseType.SUCCESS: - token_ids += resp.data['token_ids'] - elif resp.type == ResponseType.FINISH: - token_ids += resp.data['token_ids'] - if not keep_cache: - session_id = session_ids[idx] - await self.async_end(session_id=session_id) - finish_count -= 1 - else: - logger.error(f'Unexpected response: {resp.type}') - status = 1 - break - - output_token_len = [len(token_ids) for token_ids in output_token_ids] - return EngineOutput(status, output_token_ids, output_token_len) - - def batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, - ): - """batched infer.""" - coro = self.async_batched_infer( - session_ids, - token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - return self.req_sender.run_until_complete(coro) - async def async_end(self, session_id: int): """End the given session.""" return await async_end(self.req_sender, session_id) @@ -470,8 +286,7 @@ def cancel(self, session_id: int): def decode(self, input_ids, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, + multimodal: List[InputMultiModalType] = None, steps: List[int] = None, sequence_start: bool = True, sequence_end: bool = True, @@ -481,10 +296,8 @@ def decode(self, Args: input_ids (numpy.ndarray): the batch of input token ids steps (List[int]): the offset of the k/v cache - input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]): - embeddings features - input_embedding_ranges: (List[List[Tuple[int, int]]]): - the begin/end offsets of input_embeddings to input_ids + multimodal (List[InputMultiModalType]): + multimodals inputs. sequence_start (bool): indicator for starting a sequence sequence_end (bool): indicator for ending a sequence adapter_names (List[str]): The name of the adapters. 
@@ -494,39 +307,30 @@ def decode(self, batch_size = len(input_ids) def __add_messages(session_ids, input_ids, adapter_names, - input_embeddings, input_embedding_ranges): + input_multimodals): add_msgs = [] sampling_param = SamplingParam(max_new_tokens=0) batch_size = len(input_ids) - if input_embeddings is None: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - for (session_id, token_id, adapter_name, input_emb, - input_ranges) in zip(session_ids, input_ids, adapter_names, - input_embeddings, - input_embedding_ranges): + if input_multimodals is None: + input_multimodals = [None] * batch_size + for (session_id, token_id, adapter_name, + in_mm) in zip(session_ids, input_ids, adapter_names, + input_multimodals): if len(token_id) > self.max_input_len: raise RuntimeError( f'Expect input length<={self.max_input_len} ' f'but get {len(token_id)}') - cur_input_embeddings: List[InputEmbeddings] = None - if input_emb is not None and len(input_emb) > 0: - assert len(input_emb) == len(input_ranges) - cur_input_embeddings = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_emb, input_ranges) - ] msg = dict(token_ids=token_id, session_id=session_id, sampling_param=sampling_param, adapter_name=adapter_name, - input_embeddings=cur_input_embeddings, + input_multimodals=in_mm, return_logits=True) add_msgs.append(msg) req_types = [RequestType.ADD_MESSAGE] * batch_size - req_ids = self.req_sender.batched_send_async(req_types, - data=add_msgs) - return req_ids + resps = self.req_sender.batched_send_async(req_types, + data=add_msgs) + return resps if steps is not None: assert batch_size == len(steps) @@ -536,13 +340,6 @@ def __add_messages(session_ids, input_ids, adapter_names, else: adapter_names = [None] * batch_size - if input_embeddings is not None: - assert len(input_embeddings) == batch_size - assert len(input_embedding_ranges) == batch_size - else: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - session_ids = tuple(range(batch_size)) if sequence_start: for sid in session_ids: @@ -550,21 +347,14 @@ def __add_messages(session_ids, input_ids, adapter_names, dict(session_id=sid)) self._try_add_session(sid) - req_ids = __add_messages(session_ids, input_ids, adapter_names, - input_embeddings, input_embedding_ranges) - req_idx_map = dict(zip(req_ids, range(len(req_ids)))) - - finish_count = batch_size - ret = [None] * batch_size - while finish_count > 0: - resp = self.req_sender.recv_any() - if resp.req_id not in req_ids: - continue + resps = __add_messages(session_ids, input_ids, adapter_names, + multimodal) + ret = [] + for resp in resps: + resp = self.req_sender.recv(resp) assert resp.type == ResponseType.FINISH - idx = req_idx_map[resp.req_id] - ret[idx] = resp.data['logits'] - finish_count -= 1 + ret.append(resp.data['logits']) ret = pad_sequence(ret, True) diff --git a/lmdeploy/pytorch/engine/input_process.py b/lmdeploy/pytorch/engine/input_process.py new file mode 100644 index 0000000000..7f442e153b --- /dev/null +++ b/lmdeploy/pytorch/engine/input_process.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
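On the instance API, the input_embeddings/input_embedding_ranges pair is gone and a single multimodal argument (a List[Dict[str, Any]], preprocessed by the input processor introduced in the new input_process module) takes its place; streamed outputs now carry the full set of tokens generated so far instead of an incremental delta. A sketch of driving the refactored streaming API, assuming engine is an already-built PyTorch Engine and leaving the model-specific multimodal payload out:

from lmdeploy.messages import GenerationConfig, ResponseType


async def generate(engine, prompt_ids):
    """Stream tokens from an EngineInstance after this refactor."""
    instance = engine.create_instance()
    gen_config = GenerationConfig(max_new_tokens=128)
    token_ids = []
    async for out in instance.async_stream_infer(session_id=0,
                                                 input_ids=prompt_ids,
                                                 gen_config=gen_config,
                                                 multimodal=None):
        if out.status not in (ResponseType.SUCCESS, ResponseType.FINISH):
            break
        # token_ids is the whole generated sequence so far, not a delta
        token_ids = out.token_ids
    await instance.async_end(session_id=0)
    return token_ids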
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
+
+TypeModelMetas = Dict[str, Any]
+
+InputMultiModalType = List[Dict[str, Any]]
+
+
+@dataclass
+class PreprocessInputResult:
+    """results of preprocess input."""
+    input_ids: List[int]
+    input_multimodals: Optional[MultiModalInputs] = None
+    model_metas: Optional[TypeModelMetas] = None
+
+
+class BaseModelInputProcessor(ABC):
+    """processor of model inputs."""
+
+    @abstractmethod
+    def preprocess_input(self,
+                         input_ids: List[int],
+                         input_mms: InputMultiModalType = None,
+                         **kwargs) -> PreprocessInputResult:
+        """preprocess input."""
+        raise NotImplementedError('Not implemented.')
+
+
+class DefaultModelInputProcessor(BaseModelInputProcessor):
+    """default model input processor."""
+
+    def preprocess_input(self,
+                         input_ids: List[int],
+                         input_mms: MultiModalInputs = None,
+                         **kwargs) -> PreprocessInputResult:
+        """preprocess input."""
+        return PreprocessInputResult(
+            input_ids=input_ids,
+            input_multimodals=input_mms,
+        )
diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index 24cb336d71..f7ca9c5116 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
 import json
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple
@@ -298,9 +299,15 @@ def __init__(self,
         self.ignore_eos = ignore_eos
         self.tokenizer = tokenizer
 
-    def __call__(self, all_ids: torch.LongTensor,
-                 guided_input_ids: torch.LongTensor,
-                 scores: torch.FloatTensor) -> torch.FloatTensor:
+    async def _wait_stream_once(self):
+        """wait stream once."""
+        stream = torch.cuda.current_stream()
+        if not stream.query():
+            await asyncio.sleep(0)
+
+    async def __call__(self, all_ids: torch.LongTensor,
+                       guided_input_ids: torch.LongTensor,
+                       scores: torch.FloatTensor) -> torch.FloatTensor:
         r"""
         Args:
             all_ids (torch.LongTensor): All the token ids.
@@ -320,6 +327,7 @@ def __call__(self, all_ids: torch.LongTensor,
 
         custom_logits_processors = self.sampling_inputs.logits_processors
         if any(custom_logits_processors):
+            await self._wait_stream_once()
             scores = _apply_custom_logits_processors(custom_logits_processors,
                                                      all_ids, scores)
 
@@ -343,8 +351,10 @@ def __call__(self, all_ids: torch.LongTensor,
             stop_mask = torch.where(self.ignore_eos[:, None], stop_mask,
                                     False)
             scores = _process_bad_words_(scores, stop_words, stop_mask)
-        scores = _guided_sampling(sampling_inputs.response_formats, scores,
-                                  guided_input_ids, self.tokenizer)
+        if guided_input_ids is not None:
+            await self._wait_stream_once()
+            scores = _guided_sampling(sampling_inputs.response_formats, scores,
+                                      guided_input_ids, self.tokenizer)
         return scores
 
     @torch.inference_mode()
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 03fb083265..9b98dfb1fe 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -135,21 +135,26 @@ def model_forward(
     stream = stream or torch.cuda.current_stream()
     with torch.cuda.stream(stream):
         # forward
-        inputs = inputs.to_device('cuda')
         ctx_mgr = model.ctx_mgr
         context = ctx_mgr.build_context(
             inputs=inputs,
+            model_config=cache_engine.model_config,
             world_size=world_size,
             kv_caches=cache_engine.gpu_cache,
             kv_quant_policy=cache_engine.cache_config.quant_policy,
         )
         with ctx_mgr.context(context):
+            model_metas = None
+            model_metas = model.update_model_metas(
+                past_key_values=cache_engine.gpu_cache,
+                context=context,
+            )
             input_dict = model.prepare_inputs_for_generation(
                 past_key_values=cache_engine.gpu_cache,
                 context=context,
             )
             output = model(**input_dict)
-    return dict(hidden_states=output)
+    return dict(hidden_states=output, model_metas=model_metas)
 
 
 SwapMap = Dict[int, int]
@@ -177,6 +182,10 @@ def get_logits(self, hidden_states: torch.Tensor):
         """get logits of model output."""
         raise NotImplementedError('Not implemented.')
 
+    def get_input_processor(self):
+        """get input processor."""
+        raise NotImplementedError('Not implemented.')
+
 
 class BaseModelAgent(AutoModelAgent):
     """Base model agent.
@@ -288,8 +297,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_in_map=swap_in_map, swap_out_map=swap_out_map) await asyncio.sleep(0) - while not self.stream.query(): - await asyncio.sleep(0) return output async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, @@ -358,6 +365,10 @@ def get_spec_logits(self, hidden_states_list: List[torch.Tensor]): """get logits of model output.""" return self.speculative_model.get_logits(hidden_states_list) + def get_input_processor(self): + """get input processor..""" + return self.patched_model.get_input_processor() + @torch.inference_mode() def _tp_build_model( @@ -443,14 +454,26 @@ def _broadcast_config(cache_config): return patched_model, cache_engine, cache_config -def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream): +def _broadcast_inputs(rank: int, inputs: Any, group: dist.group, + stream: torch.cuda.Stream): """get input tensor parallel.""" # broadcast meta info if rank != 0: inputs = [None, None, None] + else: + device_inputs = inputs[0] + meta_inputs = device_inputs.to_device('meta') + inputs[0] = meta_inputs with torch.cuda.stream(stream): - dist.broadcast_object_list(inputs) + dist.broadcast_object_list(inputs, group=group) + if rank == 0: + device_inputs.broadcast() + else: + device_inputs = inputs[0].broadcast() + + inputs[0] = device_inputs + return inputs @@ -463,6 +486,7 @@ def _tp_model_loop( adapters: Dict[str, str], world_size: int, barrier: mp.Barrier, + cpu_group: dist.group, ): """Start model loops for tensor parallel model inference. @@ -488,11 +512,12 @@ def _tp_model_loop( while True: barrier.wait() inputs, swap_in_map, swap_out_map = _broadcast_inputs( - rank, None, stream) + rank, None, cpu_group, stream) cache_swapping(cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + inputs = inputs.to_device('cuda') model_forward( patched_model, @@ -524,10 +549,13 @@ def _start_tp_process(proc_id: int, try: from lmdeploy.pytorch.check_env import check_env_deeplink check_env_deeplink(device_context.device_type) + timeout = timedelta(days=35600) dist.init_process_group('nccl', rank=rank, world_size=world_size, - timeout=timedelta(days=35600)) + timeout=timeout) + cpu_group = dist.new_group(timeout=timeout, backend='gloo') + kwargs['cpu_group'] = cpu_group dist_ctx = DistContext(rank=rank, world_size=world_size) torch.cuda.set_device(rank) with get_dist_manager().context(dist_ctx), get_device_manager( @@ -706,12 +734,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig, rank = 0 try: + timeout = timedelta(days=35600) dist.init_process_group('nccl', rank=rank, world_size=world_size, - timeout=timedelta(days=35600)) + timeout=timeout) + cpu_group = dist.new_group(timeout=timeout, backend='gloo') dist_ctx = DistContext(rank=rank, world_size=world_size) self._dist_ctx = dist_ctx + self._cpu_group = cpu_group except Exception as e: from traceback import print_exc logger.error(f'Rank[{rank}] failed.') @@ -782,7 +813,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, self.mp_bar.wait() rank = 0 _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map], - self.stream) + self._cpu_group, self.stream) + cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) @@ -850,8 +882,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_in_map=swap_in_map, swap_out_map=swap_out_map) await asyncio.sleep(0) - while not self.stream.query(): - await asyncio.sleep(0) return 
output async def tree_decoding(self, inputs: ModelInputs, swap_in_map: SwapMap, @@ -897,6 +927,10 @@ def get_spec_logits(self, hidden_states_list: List[torch.Tensor]): """get logits of model output.""" return self.speculative_model.get_logits(hidden_states_list) + def get_input_processor(self): + """get input processor..""" + return self.patched_model.get_input_processor() + def _exit_handler(agent: TPModelAgent): if hasattr(agent, 'patched_model'): @@ -926,7 +960,7 @@ def build_model_agent(model_path: str, custom_module_map (str): customized nn module map """ model_config = ModelConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code, dtype=dtype) + model_path, trust_remote_code=trust_remote_code, dtype=dtype, tp=tp) model_config.custom_module_map = custom_module_map speculative_model_config = None if speculative_model is not None: diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 18bd2193d4..0d20deb907 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -2,8 +2,6 @@ import asyncio import enum from dataclasses import dataclass, field -from queue import Queue -from threading import Lock, Thread from typing import Any, Awaitable, Callable, Dict, List from lmdeploy.messages import ResponseType @@ -12,25 +10,6 @@ logger = get_logger('lmdeploy') -def _raise_exception_on_finish(task: asyncio.Task) -> None: - try: - task.result() - except asyncio.CancelledError: - return - except Exception as e: - logger.exception(f'Engine loop failed with error: {e}') - - -def _ignore_exception_on_finish(task: asyncio.Task) -> None: - try: - task.result() - except asyncio.CancelledError: - return - except Exception as exc: - logger.debug(f'task: {task.get_name()} ended.') - logger.debug(f'task: {task.get_name()} exception: {exc}') - - class RequestType(enum.Enum): """Request type.""" @@ -43,24 +22,24 @@ class RequestType(enum.Enum): @dataclass -class Request: - """Request.""" +class Response: + """Response.""" - type: RequestType + type: ResponseType sender_id: int - req_id: int + event: asyncio.Event data: Any = None + err_msg: str = '' @dataclass -class Response: - """Response.""" +class Request: + """Request.""" - type: ResponseType + type: RequestType sender_id: int - req_id: int data: Any = None - err_msg: str = '' + resp: Response = None ReqList = List[Request] @@ -85,28 +64,20 @@ class RequestSender: Args: sender_id (int): The id of the sender """ - sender_id: int manager: 'RequestManager' resp_dict: Dict[int, List[Response]] = field(default_factory=dict) - _next_req_id: int = 0 _resp_que: asyncio.Queue = None - _resp_thread_que: Queue = None - _thread_safe: bool = False @classmethod def new(cls, sender_id: int, manager: 'RequestManager'): """new.""" obj = cls(sender_id=sender_id, manager=manager) - obj._thread_safe = manager.is_thread_safe() return obj @property def resp_que(self): """response queue.""" - thread_safe = self.is_thread_safe() - if thread_safe: - return self.manager.responses if self._resp_que is not None: return self._resp_que if self.manager._loop_task is None: @@ -119,27 +90,11 @@ def req_que(self): """request queue.""" return self.manager.requests - @property - def resp_thread_que(self): - """response threadsafe queue.""" - if self._resp_thread_que is None: - self._resp_thread_que = Queue() - return self._resp_thread_que - - @property - def req_thread_que(self): - """request threadsafe queue.""" - return self.manager.thread_requests - @property def event_loop(self): """get event loop.""" 
return self.manager.event_loop - def is_thread_safe(self): - """is thread safe.""" - return self._thread_safe - def is_loop_alive(self): """is loop alive.""" return self.manager.is_loop_alive() @@ -148,203 +103,72 @@ def run_until_complete(self, future: Awaitable): """run untile complete.""" return self.manager.run_until_complete(future) - def _resp_get(self): - """resp_que.get.""" - timeout = 1.0 - que = self.resp_thread_que - not_empty = que.not_empty - with not_empty: - while not que._qsize(): - not_empty.wait(timeout) - return que.get_nowait() - - async def _async_resp_get(self): - """get resp. - - Different behavior in threadsafe mode. - """ - timeout = 1 - - async def __no_threadsafe_get(): - while True: - try: - return await asyncio.wait_for(self.resp_que.get(), timeout) - except asyncio.TimeoutError: - if not self.manager.is_loop_alive(): - logger.debug('Engine loop is not alive.') - exit(1) - continue - except Exception as e: - logger.exception( - f'sender[{self.sender_id}] get response failed: {e}') - raise e - - if self.is_thread_safe(): - ret = self._resp_get() - await asyncio.sleep(0) - return ret - else: - return await __no_threadsafe_get() - def _req_put(self, reqs: Any): - """req put.""" - self.req_thread_que.put(reqs) - - async def _async_req_put(self, reqs: Any): - """async rq_que put. - - Different behavior in threadsafe mode. - """ - if self.is_thread_safe(): - self._req_put(reqs) - await asyncio.sleep(0) - else: - await self.req_que.put(reqs) - - def _prefetch_resps(self): - """prefetch from resp que. - - Different behavior in threadsafe mode. - """ - if self.is_thread_safe(): - resp_que = self.resp_thread_que - else: - resp_que = self.resp_que - num_resps = resp_que.qsize() - for _ in range(num_resps): - resp: Response = resp_que.get_nowait() - req_id = resp.req_id - self._push_resp(req_id, resp) - - def _push_resp(self, req_id: int, resp: Response): - """push response.""" - self.resp_dict.setdefault(req_id, []) - self.resp_dict[req_id].append(resp) - - def _pop_resp(self, req_id: int, default: Any = None): - """pop response.""" - if req_id not in self.resp_dict: - return default - resps = self.resp_dict[req_id] - ret = resps.pop(0) - if len(resps) == 0: - self.resp_dict.pop(req_id) - return ret + """async rq_que put.""" + self.req_que.put_nowait(reqs) def _gather_request(self, req_types: List[RequestType], data: List[Any]): """gather requests.""" - if self.manager._loop_task is None and not self.is_thread_safe(): + if self.manager._loop_task is None: self.manager.create_loop_task() assert len(req_types) == len(data) - batch_size = len(req_types) - - req_ids = list(range(self._next_req_id, - self._next_req_id + batch_size)) - self._next_req_id += batch_size - - reqs = [ - Request(type=rtype, - sender_id=self.sender_id, - req_id=req_id, - data=rdata) - for req_id, rtype, rdata in zip(req_ids, req_types, data) - ] - return req_ids, reqs - async def async_batched_send_async(self, req_types: List[RequestType], - data: List[Any]): - """Batched send request asynchronize.""" - req_ids, reqs = self._gather_request(req_types, data) - await self._async_req_put(reqs) - return req_ids - - async def async_send_async(self, req_type: RequestType, data: Any): - """send request asynchronize.""" - return (await self.async_batched_send_async(req_types=[req_type], - data=[data]))[0] + reqs = [] + resps = [] + for rtype, rdata in zip(req_types, data): + event = asyncio.Event() + resp = Response(type=ResponseType.HANDLER_NOT_EXIST, + sender_id=self.sender_id, + event=event, + data=None, + 
err_msg=None) + req = Request(type=rtype, + sender_id=self.sender_id, + data=rdata, + resp=resp) + resps.append(resp) + reqs.append(req) + return resps, reqs def batched_send_async(self, req_types: List[RequestType], - data: List[Any]) -> List[int]: - """Batched send request asynchronize. - - Different behavior in threadsafe mode. - """ - if not self.is_thread_safe(): - coro = self.async_batched_send_async(req_types, data) - return self.run_until_complete(coro) - - req_ids, reqs = self._gather_request(req_types, data) + data: List[Any]): + """Batched send request asynchronize.""" + resps, reqs = self._gather_request(req_types, data) self._req_put(reqs) - return req_ids + return resps - def send_async(self, req_type: RequestType, data: Any) -> int: + def send_async(self, req_type: RequestType, data: Any): """send request asynchronize.""" return self.batched_send_async(req_types=[req_type], data=[data])[0] - async def async_recv_any(self) -> Response: - """receive any response.""" - self._prefetch_resps() - for req_id in self.resp_dict: - ret = self._pop_resp(req_id, default=None) - if ret is not None: - return ret - return await self._async_resp_get() - - def recv_any(self) -> Response: - """receive any response.""" - coro = self.async_recv_any() - return self.run_until_complete(coro) - - def recv_all(self, req_id: int, block: bool = True): - """revceive all response with req_id.""" - self._prefetch_resps() - resps = self.resp_dict.pop(req_id, []) - return resps - - async def async_recv(self, req_id: int) -> Response: + async def async_recv(self, resp: Response) -> Response: """receive response of given request id async.""" - ret = self._pop_resp(req_id, default=None) - if ret is not None: - return ret - - # check resp que - while True: - resp: Response = await self._async_resp_get() - if resp.req_id != req_id: - self._push_resp(req_id, resp) - else: - return resp - - def recv(self, req_id: int) -> Response: - """receive response of given request id. - - Different behavior in threadsafe mode. 
- """ - if not self.is_thread_safe(): - coro = self.async_recv(req_id) - return self.run_until_complete(coro) - - ret = self._pop_resp(req_id, default=None) - if ret is not None: - return ret - - # check resp que - while True: - resp: Response = self._resp_get() - if resp.req_id != req_id: - self._push_resp(req_id, resp) - else: - return resp + event = resp.event + while not event.is_set(): + try: + await asyncio.wait_for(event.wait(), 1) + except asyncio.TimeoutError: + if self.is_loop_alive(): + continue + logger.debug('Engine main loop failed.') + break + event.clear() + return resp + + def recv(self, resp: Response) -> Response: + """receive response of given request id.""" + coro = self.async_recv(resp) + return self.run_until_complete(coro) async def async_send(self, req_type: RequestType, data: Any): """send and receive synchronize.""" - req_id = await self.async_send_async(req_type, data) - return await self.async_recv(req_id) + resp = self.send_async(req_type, data) + return await self.async_recv(resp) def send(self, req_type: RequestType, data: Any) -> Response: """send and receive synchronize.""" - req_id = self.send_async(req_type, data) - return self.recv(req_id) + resp = self.send_async(req_type, data) + return self.recv(resp) def response_callback(self, resp: Response): """response callback.""" @@ -354,7 +178,7 @@ def response_callback(self, resp: Response): class RequestManager: """Request manager.""" - def __init__(self, thread_safe: bool = False): + def __init__(self): self.senders: Dict[int, RequestSender] = dict() self.callbacks: Dict[RequestType, Callable] = dict() self.request_priority: List[RequestType] = [ @@ -365,17 +189,7 @@ def __init__(self, thread_safe: bool = False): self.requests: asyncio.Queue = None self._loop_task: asyncio.Future = None self._loop_coro: Callable = None - self._thread_safe = thread_safe self._next_sender_id = 0 - self._mutex = Lock() - self._loop_thread: Thread = None - - self.thread_requests: Queue = None - # every sender has it's own responses, this responses is - # only used in thread safe mode. 
- self.responses: asyncio.Queue = None - if thread_safe: - self.thread_requests = Queue() def create_loop_task(self): """create coro task.""" @@ -385,7 +199,6 @@ def create_loop_task(self): 'Please set loop task with manager.start_loop') loop_unshielded = event_loop.create_task(self._loop_coro(), name='EngineMainLoop') - loop_unshielded.add_done_callback(_raise_exception_on_finish) self._loop_task = asyncio.shield(loop_unshielded) self.requests = asyncio.Queue() return self._loop_task @@ -398,105 +211,17 @@ def event_loop(self): else: return self._loop_task.get_loop() - def is_thread_safe(self): - """is thread safe.""" - return self._thread_safe - def start_loop(self, loop: asyncio.Task): """start main loop.""" self._loop_coro = loop - def __get_thread_reqs(): - """get thread reqs.""" - num_reqs = self.thread_requests.qsize() - reqs = [] - for _ in range(num_reqs): - tmp_reqs = self.thread_requests.get_nowait() - if isinstance(tmp_reqs, Request): - tmp_reqs = [tmp_reqs] - reqs += tmp_reqs - return reqs - - async def __async_get_req(event_loop): - """async get request.""" - que = self.thread_requests - not_empty = que.not_empty - with not_empty: - while not que._qsize(): - await event_loop.run_in_executor(None, not_empty.wait, 1.0) - reqs = que.get_nowait() - if isinstance(reqs, Request): - reqs = [reqs] - return reqs - - async def __req_loop(): - """req loop.""" - event_loop = asyncio.get_event_loop() - while True: - # get reqs - reqs = __get_thread_reqs() - if len(reqs) == 0: - reqs = await __async_get_req(event_loop) - self.requests.put_nowait(reqs) - - def __put_thread_resps(resps: List[Response]): - """put thread resps.""" - for resp in resps: - sender = self.senders.get(resp.sender_id, None) - if sender is None: - continue - sender.resp_thread_que.put_nowait(resp) - - async def __resp_loop(): - """resp loop.""" - while True: - num_resps = self.responses.qsize() - - if num_resps == 0: - resps = [await self.responses.get()] - else: - resps = [] - for _ in range(num_resps): - resps.append(self.responses.get_nowait()) - __put_thread_resps(resps) - await asyncio.sleep(0) - - def __run_forever(event_loop: asyncio.BaseEventLoop): - """run forever.""" - logger.debug('start thread run forever.') - asyncio.set_event_loop(event_loop) - self.responses = asyncio.Queue() - self.create_loop_task() - req_loop = event_loop.create_task(__req_loop(), - name='RunForeverReqLoop') - req_loop.add_done_callback(_ignore_exception_on_finish) - resp_loop = event_loop.create_task(__resp_loop(), - name='RunForeverRespLoop') - resp_loop.add_done_callback(_ignore_exception_on_finish) - self.event_loop.run_forever() - - if self.is_thread_safe(): - event_loop = asyncio.new_event_loop() - self._loop_thread = Thread(target=__run_forever, - args=(event_loop, ), - daemon=True) - self._loop_thread.start() + def stop_loop(self): + if self.is_loop_alive(): + self._loop_task.cancel() def is_loop_alive(self): """check if main loop is alive.""" - def __check_threadsafe(): - if self._loop_thread is None: - return False - if not self._loop_thread.is_alive(): - return False - if self._loop_task is None: - return False - return not self._loop_task.done() - - if self.is_thread_safe(): - return __check_threadsafe() - if self._loop_task is None: logger.debug('loop task has not been created.') return False @@ -508,12 +233,11 @@ def __check_threadsafe(): def build_sender(self): """create a new sender.""" - with self._mutex: - sender_id = self._next_sender_id - self._next_sender_id += 1 - new_sender = RequestSender.new(sender_id, self) 
- self.senders[sender_id] = new_sender - return new_sender + sender_id = self._next_sender_id + self._next_sender_id += 1 + new_sender = RequestSender.new(sender_id, self) + self.senders[sender_id] = new_sender + return new_sender def has_requests(self): """has unprocessed request.""" @@ -521,16 +245,27 @@ def has_requests(self): return False return not self.requests.empty() - def get_all_requests(self) -> Dict[RequestType, Request]: + async def get_all_requests(self) -> Dict[RequestType, Request]: """get all requests in current queue.""" num_reqs = self.requests.qsize() reqs: ReqList = [] - for _ in range(num_reqs): - elem = self.requests.get_nowait() + + def __proc_reqs(elem): + """proc reqs.""" + nonlocal reqs if isinstance(elem, Request): elem = [elem] reqs += elem + if num_reqs == 0: + elem = await self.requests.get() + __proc_reqs(elem) + num_reqs = self.requests.qsize() + + for _ in range(num_reqs): + elem = self.requests.get_nowait() + __proc_reqs(elem) + # gather requests reqs_by_type: Dict[RequestType, Request] = dict( (t, []) for t in RequestType) @@ -548,11 +283,7 @@ def set_request_priority(self, priority: List[RequestType]): def response(self, resp: Response): """send response.""" - if resp.sender_id not in self.senders: - logger.warning(f'sender {resp.sender_id} not exist. ' - f'Send {resp} failed.') - return - self.senders[resp.sender_id].response_callback(resp) + resp.event.set() def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs): """process reqs with given req type.""" @@ -563,19 +294,18 @@ def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs): else: # TODO: send error message for req in reqs: - resp = Response(ResponseType.HANDLER_NOT_EXIST, - sender_id=req.sender_id, - req_id=req.req_id, - err_msg=(f'callback for {req_type}' - ' not exists.')) + resp = req.resp + resp.type = ResponseType.HANDLER_NOT_EXIST + resp.err_msg = (f'callback for {req_type}' + ' not exists.') self.response(resp) - def step(self, **kwargs): + async def step(self, **kwargs): """handle requests. Should only be called in loop task. 
""" - reqs_by_type = self.get_all_requests() + reqs_by_type = await self.get_all_requests() # handle requests for req_type in self.request_priority: diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py index 3790cf0f66..b62ddef80a 100644 --- a/lmdeploy/pytorch/kernels/cuda/__init__.py +++ b/lmdeploy/pytorch/kernels/cuda/__init__.py @@ -9,6 +9,7 @@ from .multinomial_sampling import multinomial_sampling from .pagedattention import paged_attention_fwd from .rms_norm import rms_norm +from .w8a8_fused_moe import fused_moe_w8a8 from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8, rms_norm_dynamic_quant) @@ -28,4 +29,5 @@ 'rms_norm_dynamic_quant', 'flash_attention_fwd', 'flatten_kv_cache', + 'fused_moe_w8a8', ] diff --git a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py index 13b9841e9b..395b0c427e 100644 --- a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py +++ b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py @@ -2,210 +2,95 @@ import triton from triton import language as tl -from .triton_utils import get_kernel_meta, wrap_jit_func - def get_cuda_autotune_config(): return [ - # most used - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, - num_stages=4, - num_warps=4), - # # other - # triton.Config( - # { - # 'BLOCK_SIZE_M': 128, - # 'BLOCK_SIZE_N': 256, - # 'BLOCK_SIZE_K': 64, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=3, - # num_warps=8), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 64, - # 'BLOCK_SIZE_N': 256, - # 'BLOCK_SIZE_K': 32, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 128, - # 'BLOCK_SIZE_N': 128, - # 'BLOCK_SIZE_K': 32, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 64, - # 'BLOCK_SIZE_N': 128, - # 'BLOCK_SIZE_K': 32, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 128, - # 'BLOCK_SIZE_N': 32, - # 'BLOCK_SIZE_K': 32, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 64, - # 'BLOCK_SIZE_N': 32, - # 'BLOCK_SIZE_K': 32, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=5, - # num_warps=2), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 32, - # 'BLOCK_SIZE_N': 64, - # 'BLOCK_SIZE_K': 32, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=5, - # num_warps=2), - # # Good config for fp8 inputs. 
- # triton.Config( - # { - # 'BLOCK_SIZE_M': 128, - # 'BLOCK_SIZE_N': 256, - # 'BLOCK_SIZE_K': 128, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=3, - # num_warps=8), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 256, - # 'BLOCK_SIZE_N': 128, - # 'BLOCK_SIZE_K': 128, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=3, - # num_warps=8), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 256, - # 'BLOCK_SIZE_N': 64, - # 'BLOCK_SIZE_K': 128, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 64, - # 'BLOCK_SIZE_N': 256, - # 'BLOCK_SIZE_K': 128, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 128, - # 'BLOCK_SIZE_N': 128, - # 'BLOCK_SIZE_K': 128, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 128, - # 'BLOCK_SIZE_N': 64, - # 'BLOCK_SIZE_K': 64, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), - # triton.Config( - # { - # 'BLOCK_SIZE_M': 128, - # 'BLOCK_SIZE_N': 32, - # 'BLOCK_SIZE_K': 64, - # 'GROUP_SIZE_M': 8 - # }, - # num_stages=4, - # num_warps=4), + triton.Config({ + 'BLOCK_SIZE_N': 64, + 'GROUP_SIZE_M': 8, + }, + num_stages=3, + num_warps=4), ] @triton.jit -def _get_unpacked_order(offs_n, elem_per_int: tl.constexpr): - """get unpacked order.""" - origin_order = offs_n % elem_per_int - unpacked_order = (origin_order & 1) * 4 + origin_order // 2 - return unpacked_order +def _dequant_s4_to_f16x2(weight, shift: tl.constexpr, is_top: tl.constexpr): + + immLut: tl.constexpr = (0xf0 & 0xcc) | 0xaa + BOTTOM_MASK: tl.constexpr = 0x000f000f + TOP_MASK: tl.constexpr = 0x00f000f0 + I4s_TO_F16s_MAGIC_NUM: tl.constexpr = 0x64006400 + FP16_TOP_MAGIC_NUM: tl.constexpr = 0x64006400 + ONE_SIXTEENTH: tl.constexpr = 0x2c002c00 + NEG_64: tl.constexpr = 0xd400d400 + + if shift: + weight = weight >> 8 + + if is_top: + return tl.inline_asm_elementwise("""{ + .reg .b32 tmp; + lop3.b32 tmp, $2, $3, $4, $5; + fma.rn.f16x2 tmp, tmp, $6, $7; + mov.b32 {$0, $1}, tmp; + }""", + '=h,=h,r,n,n,n,r,r', + args=[ + weight, TOP_MASK, + I4s_TO_F16s_MAGIC_NUM, immLut, + ONE_SIXTEENTH, NEG_64 + ], + dtype=(tl.float16, tl.float16), + is_pure=True, + pack=1) + else: + return tl.inline_asm_elementwise("""{ + .reg .b32 tmp; + lop3.b32 tmp, $2, $3, $4, $5; + sub.f16x2 tmp, tmp, $6; + mov.b32 {$0, $1}, tmp; + }""", + '=h,=h,r,n,n,n,r', + args=[ + weight, BOTTOM_MASK, + I4s_TO_F16s_MAGIC_NUM, immLut, + FP16_TOP_MAGIC_NUM + ], + dtype=(tl.float16, tl.float16), + is_pure=True, + pack=1) @triton.jit -def _broadcast_pack(weight, width: tl.constexpr): - """broadcast pack.""" - broadcast_tmp = tl.arange(0, width) +def _unpack_weight(weight): + """unpack weight.""" + # broadcast and shift + width: tl.constexpr = 8 BLOCK_SIZE_K: tl.constexpr = weight.shape[0] BLOCK_SIZE_QN: tl.constexpr = weight.shape[1] BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_QN * width - weight = tl.broadcast(weight[:, :, None], broadcast_tmp[None, None, :]) - weight = tl.reshape(weight, (BLOCK_SIZE_K, BLOCK_SIZE_N)) - return weight + w0, w1 = _dequant_s4_to_f16x2(weight, False, False) + w2, w3 = _dequant_s4_to_f16x2(weight, False, True) + w4, w5 = _dequant_s4_to_f16x2(weight, True, False) + w6, w7 = _dequant_s4_to_f16x2(weight, True, True) -@triton.jit -def _unpack_weight(weight, order): - """unpack weight.""" - weight = _broadcast_pack(weight, 8) - weight = weight >> (order * 4) - # cast to float16 - immLut = (0xf0 & 0xcc) | 0xaa - BOTTOM_MASK = 0xf - I4s_TO_F16s_MAGIC_NUM = 0x6400 - 
FP16_TOP_MAGIC_NUM = 0x6400 - weight = tl.inline_asm_elementwise( - """lop3.b32 $1, $1, $2, $3, $4; - sub.f16x2 $1, $1, $5; - mov.b32 {$0, _}, $1;""", - '=h, r, n, n, n, r', [ - weight, BOTTOM_MASK, I4s_TO_F16s_MAGIC_NUM, immLut, - FP16_TOP_MAGIC_NUM - ], - dtype=tl.float16, - is_pure=False, - pack=1) - return weight + w04 = tl.join(w0, w4) + w15 = tl.join(w1, w5) + w26 = tl.join(w2, w6) + w37 = tl.join(w3, w7) + w0246 = tl.join(w04, w26) + w1357 = tl.join(w15, w37) + weight = tl.join(w0246, w1357) + + return weight.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N) @triton.autotune( configs=get_cuda_autotune_config(), - key=['M_NEXT_P2', 'N', 'K'], + key=['N', 'K'], ) -@wrap_jit_func @triton.jit def awq_linear_kernel( a_ptr, @@ -225,12 +110,9 @@ def awq_linear_kernel( stride_zk: tl.constexpr, stride_zn: tl.constexpr, # stride_cm, - stride_ck: tl.constexpr, stride_cn: tl.constexpr, # Meta-parameters - M_NEXT_P2: tl.constexpr, - Q_GROUP_SIZE: tl.constexpr, - SPLIT_K_ITERS: tl.constexpr, + SPLIT_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # @@ -239,19 +121,13 @@ def awq_linear_kernel( """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ - ELEM_PER_INT = 8 - if Q_GROUP_SIZE > BLOCK_SIZE_K: - GROUP_SIZE_K: tl.constexpr = BLOCK_SIZE_K - else: - GROUP_SIZE_K: tl.constexpr = Q_GROUP_SIZE - K_PER_GROUP: tl.constexpr = Q_GROUP_SIZE // GROUP_SIZE_K # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. # See above `L2 Cache Optimizations` section for details. + kid = tl.program_id(axis=1) pid = tl.program_id(axis=0) - split_kid = tl.program_id(axis=1) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -267,8 +143,7 @@ def awq_linear_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N BLOCK_SIZE_QN: tl.constexpr = BLOCK_SIZE_N // 8 offs_wn = pid_n * BLOCK_SIZE_QN + tl.arange(0, BLOCK_SIZE_QN) - offs_k = tl.arange(0, GROUP_SIZE_K) - unpacked_order = _get_unpacked_order(offs_bn, ELEM_PER_INT) + offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) qw_ptrs = qw_ptr + (offs_k[:, None] * stride_wk + @@ -276,49 +151,52 @@ def awq_linear_kernel( s_ptrs = s_ptr + offs_bn * stride_sn qz_ptrs = qz_ptr + offs_wn * stride_zn - # split k - NUM_K_BLOCKS = K // GROUP_SIZE_K - K_PER_SPLIT = tl.cdiv(NUM_K_BLOCKS, SPLIT_K_ITERS) - k_start = split_kid * K_PER_SPLIT - k_last = min(k_start + K_PER_SPLIT, NUM_K_BLOCKS) - a_ptrs += k_start * GROUP_SIZE_K * stride_ak - qw_ptrs += k_start * GROUP_SIZE_K * stride_wk - qg_id = k_start // K_PER_GROUP - # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. 
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - s = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty) - zs = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty) + + k_start = kid + k_last = K // BLOCK_SIZE_K # prefetch - next_qw = tl.load(qw_ptrs) - qw_ptrs += GROUP_SIZE_K * stride_wk + a_ptrs += k_start * BLOCK_SIZE_K * stride_ak + qw_ptrs += k_start * BLOCK_SIZE_K * stride_wk + s_ptrs += k_start * stride_sk + qz_ptrs += k_start * stride_zk + qw = tl.load(qw_ptrs) + qz = tl.load(qz_ptrs)[None, :] + s = tl.load(s_ptrs)[None, :] + qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk + s_ptrs += SPLIT_K * stride_sk + qz_ptrs += SPLIT_K * stride_zk + + for k in tl.range(k_start, k_last, SPLIT_K, num_stages=3): - for k in range(k_start, k_last): + # load a a = tl.load(a_ptrs) - qw = next_qw - if k + 1 < k_last: - next_qw = tl.load(qw_ptrs) - w = _unpack_weight(qw, unpacked_order) - - if k == k_start or k % K_PER_GROUP == 0: - s = tl.load(s_ptrs + qg_id * stride_sk)[None, :] - qz = tl.load(qz_ptrs + qg_id * stride_zk)[None, :] - qg_id += 1 - z = _unpack_weight(qz, unpacked_order) - zs = -z * s - b = w * s + zs + + # unpack b + z = _unpack_weight(qz) + w = _unpack_weight(qw) + b = (w - z) * s + + # load next q + mask = k + SPLIT_K < k_last + qz = tl.load(qz_ptrs, mask=mask)[None, :] + s = tl.load(s_ptrs, mask=mask)[None, :] + qw = tl.load(qw_ptrs, mask=mask) # We accumulate along the K dimension. - accumulator += tl.dot(a, b) + accumulator = tl.dot(a, b, acc=accumulator) # Advance the ptrs to the next K block. - a_ptrs += GROUP_SIZE_K * stride_ak - qw_ptrs += GROUP_SIZE_K * stride_wk + a_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_ak + qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk + s_ptrs += SPLIT_K * stride_sk + qz_ptrs += SPLIT_K * stride_zk c = accumulator.to(tl.float16) @@ -329,11 +207,11 @@ def awq_linear_kernel( c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if stride_ck > 0: - c_ptrs += split_kid * stride_ck - tl.store(c_ptrs, c, mask=c_mask) + + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, c, mask=c_mask, sem='relaxed', scope='gpu') else: - tl.atomic_add(c_ptrs, c, mask=c_mask) + tl.store(c_ptrs, c, mask=c_mask) def awq_linear(x, qweight, scales, qzeros): @@ -341,18 +219,24 @@ def awq_linear(x, qweight, scales, qzeros): M = x.size(0) K = qweight.size(0) N = scales.size(1) - SPLIT_K_ITERS = 4 group_size = K // scales.size(0) + SPLIT_K = max(1, K // 4096) def grid(META): """grid.""" - return (triton.cdiv(M, META['BLOCK_SIZE_M']) * - triton.cdiv(N, META['BLOCK_SIZE_N']), SPLIT_K_ITERS) + return ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * + triton.cdiv(N, META['BLOCK_SIZE_N']), + SPLIT_K, + ) - out = scales.new_empty(M, SPLIT_K_ITERS, N) - M_NEXT_P2 = triton.next_power_of_2(M) + if SPLIT_K > 1: + out = scales.new_zeros(M, N) + else: + out = scales.new_empty(M, N) - kernel_meta = get_kernel_meta(x) + BLOCK_SIZE_M = triton.next_power_of_2(M) + BLOCK_SIZE_M = max(16, min(128, BLOCK_SIZE_M)) awq_linear_kernel[grid]( # Pointers to matrices x, @@ -373,12 +257,11 @@ def grid(META): stride_zk=qzeros.stride(0), stride_zn=qzeros.stride(1), # stride_cm=out.stride(0), - stride_ck=out.stride(1), - stride_cn=out.stride(2), + stride_cn=out.stride(1), # Meta-parameters - M_NEXT_P2=M_NEXT_P2, - Q_GROUP_SIZE=group_size, - SPLIT_K_ITERS=SPLIT_K_ITERS, - **kernel_meta) + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_K=group_size, + SPLIT_K=SPLIT_K, + ) - return out.sum(1) + return out diff --git 
a/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py new file mode 100644 index 0000000000..4907d92ac5 --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py @@ -0,0 +1,344 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from: https://github.com/vllm-project/vllm +import torch +import triton +import triton.language as tl + +from .activation import silu_and_mul +from .blocked_gemm_fp8 import quant_fp8 +from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize + + +def get_cuda_autotune_config(): + return [ + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + }, + num_stages=4, + num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + }, + num_stages=4, + num_warps=4), + ] + + +@triton.autotune( + configs=get_cuda_autotune_config(), + key=['N', 'K', 'M_NP2'], + warmup=10, + rep=25, +) +@triton.jit +def fused_moe_blocked_f8_kernel( + A, + A_scale, + B, + B_scale, + C, + SortedIdx, + ExpStart, + ExpEnd, + Weights, + N: tl.constexpr, + K: tl.constexpr, + group_ak: tl.constexpr, + group_bk: tl.constexpr, + group_bn: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_asm, + stride_ask: tl.constexpr, + stride_be: tl.constexpr, + stride_bn: tl.constexpr, + stride_bk: tl.constexpr, + stride_bse: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, + stride_cm, + stride_cn: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + M_NP2: tl.constexpr, + ENABLE_WEIGHTS: tl.constexpr, + top_k: tl.constexpr, + expert_offset: tl.constexpr, + reindex_a: tl.constexpr, + reindex_c: tl.constexpr, +): + """fused moe kernel.""" + exp_id = tl.program_id(1) + pid = tl.program_id(0) + + exp_start = tl.load(ExpStart + exp_id + expert_offset) + exp_end = tl.load(ExpEnd + exp_id + expert_offset) + M = exp_end - exp_start + if M <= 0: + return + + num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N: + return + + offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask_sid = offs_sid < exp_end + sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + if reindex_a: + offs_am = sid // top_k + else: + offs_am = offs_sid + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + as_ptrs = A_scale + offs_am + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), + BLOCK_SIZE_N) + + # deepseek has 160 experts, exp index would overflow int32 + exp_id = exp_id.to(tl.int64) + exp_off = stride_be * exp_id + b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + offs_bsn = pid_n * BLOCK_SIZE_N // group_bn + as_ptrs = A_scale + offs_am * stride_asm + bs_ptrs = B_scale + stride_bse * exp_id + offs_bsn * stride_bsn + + acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs) + acc_ratio = 1 / acc_scale + accumulator = 
tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # load scales + k_start = (k + 1) * BLOCK_SIZE_K + offs_ksa = k_start // group_ak + offs_ksb = k_start // group_bk + a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, + mask=k_start < K, + other=1.0) + b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, + mask=k_start < K, + other=1.0) + + # load ab + a = tl.load(a_ptrs, + mask=mask_sid[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + # mma + accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None]) + + # update scales and ratio + new_acc_scale = a_scale * b_scale + acc_ratio = acc_scale / new_acc_scale + acc_scale = new_acc_scale + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator * (acc_ratio * acc_scale)[:, None] + + if ENABLE_WEIGHTS: + weight = tl.load(Weights + sid, mask=mask_sid) + c = c * weight[:, None].to(c.dtype) + + c = c.to(C.dtype.element_ty) + + if reindex_c: + offs_cm = sid + else: + offs_cm = offs_sid + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :] + tl.store(c_ptrs, c, mask=mask_sid[:, None]) + + +def fused_moe_blocked_fp8_kernel_launcher( + A: torch.Tensor, + A_scale: torch.Tensor, + B: torch.Tensor, + B_scale: torch.Tensor, + C: torch.Tensor, + sorted_idx: torch.Tensor, + exp_start: torch.Tensor, + exp_end: torch.Tensor, + weights: torch.Tensor, + enable_weights: bool = False, + top_k: int = 1, + num_tokens: int = None, + expert_offset: int = 0, + reindex_a: bool = True, + reindex_c: bool = True, +): + """fused moe kernel launcher.""" + + if num_tokens is None: + num_tokens = A.size(0) + M_NP2 = triton.next_power_of_2(num_tokens) + M_NP2 = max(64, M_NP2) + E, N, K = B.shape + + assert A.dim() == 2 + assert A_scale.dim() == 2 + assert B.dim() == 3 + assert B_scale.dim() == 3 + + assert K % A_scale.size(1) == 0 + assert K % B_scale.size(2) == 0 + assert N % B_scale.size(1) == 0 + + group_ak = K // A_scale.size(1) + group_bk = K // B_scale.size(2) + group_bn = N // B_scale.size(1) + + def _grid_fn(META): + grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) * + triton.cdiv(N, META['BLOCK_SIZE_N']), E) + return grid + + A = A.flatten(0, -2) + C = C.flatten(0, -2) + + BLOCK_SIZE_K = group_bk + GROUP_SIZE_M = 8 + grid = _grid_fn + fused_moe_blocked_f8_kernel[grid]( + A, + A_scale, + B, + B_scale, + C, + sorted_idx, + exp_start, + exp_end, + weights, + N=N, + K=K, + group_ak=group_ak, + group_bk=group_bk, + group_bn=group_bn, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_asm=A_scale.stride(0), + stride_ask=A_scale.stride(1), + stride_be=B.stride(0), + stride_bn=B.stride(1), + stride_bk=B.stride(2), + stride_bse=B_scale.stride(0), + stride_bsn=B_scale.stride(1), + stride_bsk=B_scale.stride(2), + stride_cm=C.stride(0), + stride_cn=C.stride(1), + ENABLE_WEIGHTS=enable_weights, + top_k=top_k, + expert_offset=expert_offset, + reindex_a=reindex_a, + reindex_c=reindex_c, + M_NP2=M_NP2, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + + +def fused_moe_blocked_fp8(input: torch.Tensor, + input_scale: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + topk: int, + out_dtype: torch.dtype = torch.float16, + expert_offset: int = 0, + num_experts: int = None, + renormalize: bool = False) -> torch.Tensor: + """fused moe.""" 
+ device = input.device + M = input.size(0) + E, N, _ = w1.shape + if num_experts is None: + num_experts = E + full_exp = num_experts == E + group_size = input.size(-1) // input_scale.size(-1) + + topk_weights = _renormalize(topk_weights, renormalize) + sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts) + + intermediate_cache1 = _make_intermediate((M, topk, N), + dtype=out_dtype, + device=device, + zeros=not full_exp) + # gate and up + fused_moe_blocked_fp8_kernel_launcher( + input, + input_scale, + w1, + w1_scale, + intermediate_cache1, + sorted_idx=sorted_idx, + exp_start=exp_start, + exp_end=exp_end, + weights=topk_weights, + enable_weights=False, + top_k=topk, + num_tokens=M, + expert_offset=expert_offset, + reindex_a=True, + reindex_c=False, + ) + + # activate + intermediate_cache1 = intermediate_cache1.flatten(0, -2) + gate_cache = silu_and_mul(intermediate_cache1) + del intermediate_cache1 + gate_cache, gate_scale = quant_fp8(gate_cache, + group_size, + dtype=input.dtype) + + intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]), + dtype=out_dtype, + device=device, + zeros=not full_exp) + # down + fused_moe_blocked_fp8_kernel_launcher( + gate_cache, + gate_scale, + w2, + w2_scale, + intermediate_cache2, + sorted_idx=sorted_idx, + exp_start=exp_start, + exp_end=exp_end, + weights=topk_weights, + enable_weights=True, + top_k=1, + num_tokens=M, + expert_offset=expert_offset, + reindex_a=False, + reindex_c=True, + ) + + ret = intermediate_cache2.sum(dim=1) + return ret diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py new file mode 100644 index 0000000000..9f992bcfef --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py @@ -0,0 +1,237 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import triton +import triton.language as tl +from torch import Tensor + + +@triton.jit +def _quant_fp8_kernel( + a_ptr, + out_ptr, + scale_ptr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + stride_am, + stride_ak: tl.constexpr, + stride_om, + stride_ok: tl.constexpr, + stride_sm, + stride_sg: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + """quant fp8 kernel.""" + group_id = tl.program_id(0) + m_id = tl.program_id(1) + + g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + + a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak + o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok + s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg + + rfp8_max = 1 / fp8_max + + a = tl.load(a_ptrs).to(tl.float32) + scale = tl.max(tl.abs(a)) * rfp8_max + out = a / scale + + out = tl.clamp(out, fp8_min, fp8_max) + out = out.to(out_ptr.dtype.element_ty) + + tl.store(o_ptrs, out) + tl.store(s_ptr, scale) + + +def quant_fp8(A: Tensor, + group_size: int, + dtype: torch.dtype = torch.float8_e4m3fn): + """quant online.""" + assert A.dim() == 2 + M, K = A.shape + assert K % group_size == 0 + num_groups = K // group_size + + finfo = torch.finfo(dtype) + fmin = finfo.min + fmax = finfo.max + + out = torch.empty_like(A, dtype=dtype) + scales = A.new_empty(M, num_groups, dtype=torch.float32) + grid = (num_groups, M) + num_warps = 4 + num_stages = 1 + _quant_fp8_kernel[grid]( + A, + out, + scales, + fp8_min=fmin, + fp8_max=fmax, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_om=out.stride(0), + stride_ok=out.stride(1), + stride_sm=scales.stride(0), + stride_sg=scales.stride(1), + GROUP_SIZE=group_size, + num_warps=num_warps, + num_stages=num_stages, + ) + + return out, scales + + +@triton.autotune(configs=[ + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 128, + }, num_stages=3, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 64, + }, num_stages=3, num_warps=4) +], + key=['N', 'K'], + warmup=5, + rep=10) +@triton.jit +def _gemm_fp8_kernel( + A, + a_scale_ptr, + B, + b_scale_ptr, + C, + M, + N: tl.constexpr, + K: tl.constexpr, + group_ak: tl.constexpr, + group_bk: tl.constexpr, + group_bn: tl.constexpr, + stride_am, + stride_ak: tl.constexpr, + stride_asm, + stride_ask: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, + stride_cm, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + """gemm fp8 kernel.""" + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + offs_bsn = pid_n * BLOCK_N // group_bn + as_ptrs = a_scale_ptr + offs_am * stride_asm + bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn + + acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs) + acc_ratio = 1 / acc_scale + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # load scales + k_start 
= (k + 1) * BLOCK_K + offs_ksa = k_start // group_ak + offs_ksb = k_start // group_bk + a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, + mask=k_start < K, + other=1.0) + b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, + mask=k_start < K, + other=1.0) + + # load ab + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + + # mma + accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None]) + + # update scales and ratio + new_acc_scale = a_scale * b_scale + acc_ratio = acc_scale / new_acc_scale + acc_scale = new_acc_scale + + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + c = accumulator * (acc_ratio * acc_scale)[:, None] + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def blocked_gemm_fp8(A: Tensor, + A_scale: Tensor, + B: Tensor, + B_scale: torch.Tensor, + out_dtype: torch.dtype = torch.float16): + """gemm fp8.""" + + def grid(META): + return (triton.cdiv(M, META['BLOCK_M']) * + triton.cdiv(N, META['BLOCK_N']), ) + + assert A.dim() == 2 + assert A_scale.dim() == 2 + assert B.dim() == 2 + assert B_scale.dim() == 2 + + M, K = A.shape + _, N = B.shape + + group_ak = triton.cdiv(K, A_scale.size(1)) + group_bk = triton.cdiv(K, B_scale.size(0)) + group_bn = triton.cdiv(N, B_scale.size(1)) + + C = A.new_empty(M, N, dtype=out_dtype) + + BLOCK_K = max(group_ak, group_bk) + + _gemm_fp8_kernel[grid]( + A, + A_scale, + B, + B_scale, + C, + M=M, + N=N, + K=K, + group_ak=group_ak, + group_bk=group_bk, + group_bn=group_bn, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_asm=A_scale.stride(0), + stride_ask=A_scale.stride(1), + stride_bk=B.stride(0), + stride_bn=B.stride(1), + stride_bsk=B_scale.stride(0), + stride_bsn=B_scale.stride(1), + stride_cm=C.stride(0), + stride_cn=C.stride(1), + BLOCK_K=BLOCK_K, + GROUP_M=8, + ) + + return C diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py index e2b2091b84..5f59ac4651 100644 --- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py @@ -31,7 +31,7 @@ def _flatten_kv_cache( stride_vos: tl.constexpr, stride_vod: tl.constexpr, stride_boff, - OUT_SIZE: tl.constexpr, + OUT_SIZE, HEAD_DIM_K: tl.constexpr, HEAD_DIM_V: tl.constexpr, BLOCK_BS: tl.constexpr, @@ -125,7 +125,7 @@ def _flatten_kv_cache_quant( stride_vod: tl.constexpr, stride_boff, quant_policy: tl.constexpr, - OUT_SIZE: tl.constexpr, + OUT_SIZE, HEAD_DIM_K: tl.constexpr, HEAD_DIM_V: tl.constexpr, BLOCK_BS: tl.constexpr, diff --git a/lmdeploy/pytorch/kernels/cuda/fused_lora.py b/lmdeploy/pytorch/kernels/cuda/fused_lora.py index d7fbb34588..3dc7e3a10b 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_lora.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_lora.py @@ -9,8 +9,8 @@ def get_autotune_config(): return [ triton.Config( { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128 }, num_stages=4, @@ -26,9 +26,26 @@ def get_autotune_config(): ] +@triton.jit +def _atomic_store(ptrs, val, mask): + """atomic store values.""" + dtype = ptrs.dtype.element_ty + if (dtype == torch.float16) | (dtype == torch.float32): + tl.atomic_add(ptrs, val, mask=mask, sem='relaxed') + else: + # bfloat16 does 
not support atomic add + origin = tl.load(ptrs, mask=mask) + val = val.to(origin.dtype) + val += origin + tl.store(ptrs, val, mask=mask) + + @triton.autotune( configs=get_autotune_config(), key=['N', 'K'], + restore_value=['c_ptr'], + warmup=5, + rep=20, ) @triton.jit def _fused_lora_kernel( @@ -44,18 +61,19 @@ def _fused_lora_kernel( adapter_ids_ptr, N: tl.constexpr, K: tl.constexpr, - stride_am: tl.constexpr, + stride_am, stride_ak: tl.constexpr, stride_lar: tl.constexpr, stride_lak: tl.constexpr, stride_lbr: tl.constexpr, stride_lbn: tl.constexpr, - stride_cm: tl.constexpr, + stride_cm, stride_cn: tl.constexpr, BLOCK_SIZE_R: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + CUM: tl.constexpr, ): """fused lora kernel.""" pid = tl.program_id(axis=0) @@ -70,87 +88,91 @@ def _fused_lora_kernel( rank_start = tl.load(rank_start_ptr + adapter_id) rank = tl.load(ranks_ptr + adapter_id) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - GROUP_SIZE_M: tl.constexpr = 1 - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + pid_m = pid if pid_m * BLOCK_SIZE_M >= M: return offs_m = (seq_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_n = tl.arange(0, BLOCK_SIZE_N) mask_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) < M - if rank == 0: - offs_cm = offs_m - offs_cn = offs_n - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[ - None, :] - c_mask = mask_cm[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, 0, mask=c_mask) - return - - offs_am = (seq_start + - (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M) - offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + - offs_k[None, :] * stride_ak) - la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak + - offs_r[None, :] * stride_lar) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B - # If it is out of bounds, set it to 0. - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - la = tl.load(la_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, la) - # Advance the ptrs to the next K block. 
- a_ptrs += BLOCK_SIZE_K * stride_ak - la_ptrs += BLOCK_SIZE_K * stride_lak - ar = accumulator.to(lora_b_ptr.dtype.element_ty) - - offs_lbn = offs_n % N - lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr + - offs_lbn * stride_lbn) - lb = tl.load(lb_ptrs, mask=tl.arange(0, BLOCK_SIZE_R)[:, None] < rank) - - c = tl.dot(ar, lb) - - scaling = tl.load(scaling_ptr + adapter_id) - c *= scaling - - c = c.to(c_ptr.dtype.element_ty) offs_cm = offs_m - offs_cn = offs_n - c_ptrs = c_ptr + stride_cm * offs_cm[:, - None] + stride_cn * offs_cn[None, :] - c_mask = mask_cm[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -def fused_lora(input: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor, - scaling: torch.LongTensor, rank_start: torch.LongTensor, - ranks: torch.LongTensor, seq_start: torch.LongTensor, - seq_lens: torch.LongTensor, adapter_ids: torch.LongTensor, - max_rank: int, max_seqlen: int): + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_n[None, :] + + if rank == 0: + if not CUM: + for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)): + mask_cn = (offs_n < N - n * BLOCK_SIZE_N) + c_mask = mask_cm[:, None] * mask_cn[None, :] + tl.store(c_ptrs, 0.0, mask=c_mask) + c_ptrs += stride_cn * BLOCK_SIZE_N + else: + + offs_am = (seq_start + + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M) + offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + + offs_k[None, :] * stride_ak) + la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak + + offs_r[None, :] * stride_lar) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + la = tl.load(la_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, la, acc=accumulator) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + la_ptrs += BLOCK_SIZE_K * stride_lak + ar = accumulator.to(lora_b_ptr.dtype.element_ty) + + scaling = tl.load(scaling_ptr + adapter_id).to(ar.dtype) + ar *= scaling + ar = tl.where( + tl.arange(0, BLOCK_SIZE_R)[None, :] < rank, ar, tl.zeros_like(ar)) + lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr + + offs_n[None, :] * stride_lbn) + + for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)): + lb = tl.load(lb_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N) + c = tl.dot(ar, lb) + + mask_cn = (offs_n < N - n * BLOCK_SIZE_N) + c_mask = mask_cm[:, None] * mask_cn[None, :] + if CUM: + _atomic_store(c_ptrs, c, mask=c_mask) + else: + tl.store(c_ptrs, c, mask=c_mask) + c_ptrs += stride_cn * BLOCK_SIZE_N + lb_ptrs += stride_lbn * BLOCK_SIZE_N + + +def fused_lora(input: torch.Tensor, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + scaling: torch.LongTensor, + rank_start: torch.LongTensor, + ranks: torch.LongTensor, + seq_start: torch.LongTensor, + seq_lens: torch.LongTensor, + adapter_ids: torch.LongTensor, + max_rank: int, + max_seqlen: int, + output: torch.Tensor = None, + cum: bool = False): """fused lora.""" def grid(META): - ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M']) * - triton.cdiv(N, META['BLOCK_SIZE_N'])), batch_size) + ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M'])), batch_size) return ret assert input.dim() == 2 @@ -158,7 +180,12 @@ def grid(META): M, K = input.shape N = lora_b.size(1) - output = input.new_empty((M, N)) + if output is None: + output = input.new_empty((M, N)) + cum = False + else: + assert output.size(0) == M + assert output.size(1) == N BLOCK_SIZE_R = max(16, max_rank) _fused_lora_kernel[grid]( @@ -183,6 +210,7 @@ def grid(META): stride_cm=output.stride(0), stride_cn=output.stride(1), BLOCK_SIZE_R=BLOCK_SIZE_R, + CUM=cum, ) return output diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py index 9f9771368e..9d73208c53 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py @@ -91,8 +91,6 @@ def fused_moe_kernel( if GROUP_SIZE_M == 1: pid_m = pid % num_pid_m pid_n = pid // num_pid_m - # pid_m = pid // num_pid_n - # pid_n = pid % num_pid_n else: num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group @@ -133,7 +131,7 @@ def fused_moe_kernel( b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator += tl.dot(a, b) + accumulator = tl.dot(a, b, acc=accumulator) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -271,6 +269,33 @@ def get_start_end(topk_idx: torch.Tensor, sorted_idx: torch.Tensor, return exp_start, exp_end +def _get_sorted_idx(topk_ids: torch.Tensor, num_experts: int): + """get sorted idx.""" + flatten_topk_ids = topk_ids.flatten() + sorted_idx = flatten_topk_ids.argsort() + + exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx, + num_experts) + return sorted_idx, exp_start, exp_end + + +def _renormalize(topk_weights: torch.Tensor, renormalize: bool): + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if not topk_weights.is_contiguous(): + topk_weights = topk_weights.contiguous() + return topk_weights + + +def _make_intermediate(shape: tuple, dtype: torch.dtype, device: torch.device, + zeros: bool): + """make intermediate.""" + if zeros: + return torch.zeros(shape, dtype=dtype, device=device) + else: + return torch.empty(shape, dtype=dtype, device=device) + + def 
fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -283,31 +308,17 @@ def fused_moe(hidden_states: torch.Tensor, """fused moe.""" M = hidden_states.size(0) E, N, _ = w1.shape - full_exp = False if num_experts is None: num_experts = E - elif num_experts == E: - full_exp = True - - def __get_sorted_idx(topk_ids: torch.Tensor): - flatten_topk_ids = topk_ids.flatten() - sorted_idx = flatten_topk_ids.argsort() - - exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx, - num_experts) - return sorted_idx, exp_start, exp_end - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - if not topk_weights.is_contiguous(): - topk_weights = topk_weights.contiguous() + full_exp = num_experts == E - sorted_idx, exp_start, exp_end = __get_sorted_idx(topk_ids) + topk_weights = _renormalize(topk_weights, renormalize) + sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts) - if full_exp: - intermediate_cache1 = hidden_states.new_empty((M, topk, N)) - else: - intermediate_cache1 = hidden_states.new_zeros((M, topk, N)) + intermediate_cache1 = _make_intermediate((M, topk, N), + dtype=hidden_states.dtype, + device=hidden_states.device, + zeros=not full_exp) # gate and up fused_moe_kernel_launcher( hidden_states, @@ -331,10 +342,10 @@ def __get_sorted_idx(topk_ids: torch.Tensor): gate_cache = silu_and_mul(intermediate_cache1) gate_cache = gate_cache.unflatten(0, unflat_size) - if full_exp: - intermediate_cache2 = hidden_states.new_empty((M, topk, w2.shape[1])) - else: - intermediate_cache2 = hidden_states.new_zeros((M, topk, w2.shape[1])) + intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + zeros=not full_exp) # down fused_moe_kernel_launcher( gate_cache, diff --git a/lmdeploy/pytorch/kernels/cuda/rms_norm.py b/lmdeploy/pytorch/kernels/cuda/rms_norm.py index bc994012fc..045b55e1ba 100644 --- a/lmdeploy/pytorch/kernels/cuda/rms_norm.py +++ b/lmdeploy/pytorch/kernels/cuda/rms_norm.py @@ -4,8 +4,6 @@ import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func - @triton.jit def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr): @@ -18,15 +16,6 @@ def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr): return out -@wrap_jit_func(type_hint=dict( - input=Tensor, - weight=Tensor, - output=Tensor, - input_row_stride=int, - eps=float, - N_COLS=torch.int32, - BLOCK_N=torch.int32, -)) @triton.jit def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr, eps: tl.constexpr, N_COLS: tl.constexpr, @@ -45,18 +34,6 @@ def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr, tl.store(out_ptr + offsets, out, mask=offsets < N_COLS) -@wrap_jit_func(type_hint=dict( - input=Tensor, - weight=Tensor, - residual=Tensor, - output=Tensor, - out_residual=Tensor, - input_row_stride=int, - residual_row_stride=int, - eps=float, - N_COLS=torch.int32, - BLOCK_N=torch.int32, -)) @triton.jit def add_rms_norm_kernel(input, weight, residual, output, out_residual, input_row_stride: tl.constexpr, @@ -95,6 +72,7 @@ def rms_norm(hidden_states: Tensor, hidden_states = hidden_states.contiguous() feat_size = weight.shape[0] + assert hidden_states.size(-1) == feat_size seq_len = hidden_states.numel() // hidden_states.size(-1) input_stride = hidden_states.stride(-2) @@ -103,39 +81,40 @@ def rms_norm(hidden_states: Tensor, if out is None: out = torch.empty_like(hidden_states) - 
kernel_meta = get_kernel_meta(hidden_states) grid = (seq_len, ) if residual is None: - rms_norm_kernel[grid](hidden_states, - weight, - out, - input_row_stride=input_stride, - eps=eps, - N_COLS=feat_size, - BLOCK_N=BLOCK_N, - num_warps=4, - num_stages=2, - **kernel_meta) + rms_norm_kernel[grid]( + hidden_states, + weight, + out, + input_row_stride=input_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) return out else: if out_residual is None: out_residual = torch.empty_like(hidden_states) res_stride = residual.stride(-2) - add_rms_norm_kernel[grid](hidden_states, - weight, - residual, - out, - out_residual, - input_row_stride=input_stride, - residual_row_stride=res_stride, - eps=eps, - N_COLS=feat_size, - BLOCK_N=BLOCK_N, - num_warps=4, - num_stages=2, - **kernel_meta) + add_rms_norm_kernel[grid]( + hidden_states, + weight, + residual, + out, + out_residual, + input_row_stride=input_stride, + residual_row_stride=res_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) return out, out_residual diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py new file mode 100644 index 0000000000..72d9d802a4 --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from: https://github.com/vllm-project/vllm +import torch +import triton +import triton.language as tl + +from .activation import silu_and_mul +from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize +from .triton_utils import get_kernel_meta +from .w8a8_triton_kernels import per_token_quant_int8 + + +def get_cuda_autotune_config(): + return [ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 1, + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 1, + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1, + }, + num_stages=4, + num_warps=4), + ] + + +@triton.autotune( + configs=get_cuda_autotune_config(), + key=['N', 'K', 'M_NP2'], + warmup=10, + rep=25, +) +@triton.jit +def fused_moe_w8a8_kernel( + A, + A_scale, + B, + B_scale, + C, + SortedIdx, + ExpStart, + ExpEnd, + Weights, + N: tl.constexpr, + K: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_be: tl.constexpr, + stride_bn: tl.constexpr, + stride_bk: tl.constexpr, + stride_bse: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + M_NP2: tl.constexpr, + ENABLE_WEIGHTS: tl.constexpr, + top_k: tl.constexpr, + expert_offset: tl.constexpr, + reindex_a: tl.constexpr, + reindex_c: tl.constexpr, +): + """fused moe kernel.""" + exp_id = tl.program_id(1) + pid = tl.program_id(0) + + exp_start = tl.load(ExpStart + exp_id + expert_offset) + exp_end = tl.load(ExpEnd + exp_id + expert_offset) + M = exp_end - exp_start + if M <= 0: + return + + num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - 
first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N: + return + + offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask_sid = offs_sid < exp_end + sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + if reindex_a: + offs_am = sid // top_k + else: + offs_am = offs_sid + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + as_ptrs = A_scale + offs_am + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), + BLOCK_SIZE_N) + + # deepseek has 160 experts, exp index would overflow int32 + exp_id = exp_id.to(tl.int64) + exp_off = stride_be * exp_id + b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + bs_ptrs = B_scale + exp_id * stride_bse + offs_bn + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=mask_sid[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + accumulator = tl.dot(a, b, acc=accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + ascale = tl.load(as_ptrs, mask=mask_sid) + bscale = tl.load(bs_ptrs) + c = accumulator.to(ascale.dtype) + c = c * ascale[:, None] * bscale[None, :] + + if ENABLE_WEIGHTS: + weight = tl.load(Weights + sid, mask=mask_sid) + c = c * weight[:, None].to(c.dtype) + + c = c.to(C.dtype.element_ty) + + if reindex_c: + offs_cm = sid + else: + offs_cm = offs_sid + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :] + tl.store(c_ptrs, c, mask=mask_sid[:, None]) + + +def fused_moe_w8a8_kernel_launcher( + A: torch.Tensor, + A_scale: torch.Tensor, + B: torch.Tensor, + B_scale: torch.Tensor, + C: torch.Tensor, + sorted_idx: torch.Tensor, + exp_start: torch.Tensor, + exp_end: torch.Tensor, + weights: torch.Tensor, + enable_weights: bool = False, + top_k: int = 1, + num_tokens: int = None, + expert_offset: int = 0, + reindex_a: bool = True, + reindex_c: bool = True, +): + """fused moe kernel launcher.""" + + if num_tokens is None: + num_tokens = A.size(0) + M_NP2 = triton.next_power_of_2(num_tokens) + M_NP2 = max(64, M_NP2) + E, N, K = B.shape + + assert A_scale.is_contiguous() + assert B_scale.is_contiguous() + + def _grid_fn(META): + grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) * + triton.cdiv(N, META['BLOCK_SIZE_N']), E) + return grid + + A = A.flatten(0, -2) + C = C.flatten(0, -2) + + grid = _grid_fn + kernel_meta = get_kernel_meta(A) + fused_moe_w8a8_kernel[grid]( + A, + A_scale, + B, + B_scale, + C, + sorted_idx, + exp_start, + exp_end, + weights, + N=N, + K=K, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_be=B.stride(0), + stride_bn=B.stride(1), + stride_bk=B.stride(2), + stride_bse=B_scale.stride(0), + stride_cm=C.stride(0), + stride_cn=C.stride(1), + ENABLE_WEIGHTS=enable_weights, + top_k=top_k, + expert_offset=expert_offset, + reindex_a=reindex_a, + reindex_c=reindex_c, + M_NP2=M_NP2, + **kernel_meta, + ) + + +def fused_moe_w8a8(input: torch.Tensor, + input_scale: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + topk: int, + out_dtype: 
torch.dtype = torch.float16, + expert_offset: int = 0, + num_experts: int = None, + renormalize: bool = False) -> torch.Tensor: + """fused moe.""" + device = input.device + M = input.size(0) + E, N, _ = w1.shape + if num_experts is None: + num_experts = E + full_exp = num_experts == E + + topk_weights = _renormalize(topk_weights, renormalize) + sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts) + + intermediate_cache1 = _make_intermediate((M, topk, N), + dtype=out_dtype, + device=device, + zeros=not full_exp) + # gate and up + fused_moe_w8a8_kernel_launcher( + input, + input_scale, + w1, + w1_scale, + intermediate_cache1, + sorted_idx=sorted_idx, + exp_start=exp_start, + exp_end=exp_end, + weights=topk_weights, + enable_weights=False, + top_k=topk, + num_tokens=M, + expert_offset=expert_offset, + reindex_a=True, + reindex_c=False, + ) + + # activate + unflat_size = intermediate_cache1.shape[:-1] + intermediate_cache1 = intermediate_cache1.flatten(0, -2) + gate_cache = silu_and_mul(intermediate_cache1) + del intermediate_cache1 + gate_cache = gate_cache.unflatten(0, unflat_size) + gate_cache, gate_scale = per_token_quant_int8(gate_cache, 1e-7) + + intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]), + dtype=out_dtype, + device=device, + zeros=not full_exp) + # down + fused_moe_w8a8_kernel_launcher( + gate_cache, + gate_scale, + w2, + w2_scale, + intermediate_cache2, + sorted_idx=sorted_idx, + exp_start=exp_start, + exp_end=exp_end, + weights=topk_weights, + enable_weights=True, + top_k=1, + num_tokens=M, + expert_offset=expert_offset, + reindex_a=False, + reindex_c=True, + ) + + ret = intermediate_cache2.sum(dim=1) + return ret diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py index 0d0e10ec83..a8eeb63a5f 100644 --- a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py +++ b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py @@ -14,14 +14,13 @@ tl_round = tl.math.round -def per_channel_quant(x, n_bits, dtype): +def per_channel_quant(x: torch.Tensor, dtype: torch.dtype): """Quantize the input tensor 'x' channel-wise using the given number of bits. Args: x (torch.Tensor): The input tensor to be quantized. Must be a 2-dimensional tensor. - n_bits (int): The number of bits to use for quantization. dtype (torch.dtype): The data type to which the quantized tensor should be converted. 
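# Dense floating-point reference for the two-stage expert pipeline that
# `fused_moe_w8a8` above runs in int8/fp8 (a slow sketch: quantization,
# expert sorting and tiling are omitted).  It assumes w1 packs the gate
# and up projections along dim 1 and that silu_and_mul(x) computes
# silu(x[..., :d]) * x[..., d:].
import torch
import torch.nn.functional as F


def moe_reference(x, w1, w2, topk_weights, topk_ids):
    # x: (M, K), w1: (E, 2*I, K), w2: (E, H, I), topk_*: (M, topk)
    M, topk = topk_ids.shape
    out = x.new_zeros(M, w2.shape[1])
    for m in range(M):
        for t in range(topk):
            e = int(topk_ids[m, t])
            gate_up = x[m] @ w1[e].T                  # (2*I,)
            d = gate_up.numel() // 2
            act = F.silu(gate_up[:d]) * gate_up[d:]   # (I,)
            out[m] += topk_weights[m, t] * (act @ w2[e].T)
    return out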
@@ -32,31 +31,40 @@ def per_channel_quant(x, n_bits, dtype): assert x.ndim == 2 x = x.to(torch.float32) x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0] - q_max = 2**(n_bits - 1) - 1 - q_min = -2**(n_bits - 1) - scale = x_absmax / (2**(n_bits - 1) - 1) - x_q = torch.round(x / scale).clamp(q_min, q_max).to(dtype) + qtype_info = torch.finfo( + dtype) if dtype.is_floating_point else torch.iinfo(dtype) + q_max = qtype_info.max + q_min = qtype_info.min + scale = x_absmax / q_max + x_q = x / scale + if not dtype.is_floating_point: + x_q = torch.round(x_q) + x_q = x_q.clamp(q_min, q_max).to(dtype) return x_q, scale @triton.autotune( configs=[ triton.Config({ - 'BLOCK_N': 64, + 'BLOCK_M': 128, + 'BLOCK_N': 256, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4), + num_stages=3, + num_warps=8), triton.Config({ + 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4) + num_stages=3, + num_warps=8) ], key=['N', 'K'], + warmup=5, + rep=20, ) -@triton.jit +@triton.jit(do_not_specialize=['M']) def _linear( A, B, @@ -76,6 +84,7 @@ def _linear( GROUP_SIZE_M: tl.constexpr, rms_scale_ptr, linear_scale_ptr, + ACCUMULATOR_DTYPE: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot product) on input tensors `A` and `B`, and store the result in output @@ -100,12 +109,11 @@ def _linear( offs_k = tl.arange(0, BLOCK_K) a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0) - accumulator += tl.dot(a, b) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None) + accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk c = accumulator.to(tl.float32) @@ -124,42 +132,31 @@ def _linear( @triton.autotune( configs=[ triton.Config({ - 'BLOCK_N': 64, + 'BLOCK_M': 128, + 'BLOCK_N': 256, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4), + num_stages=3, + num_warps=8), triton.Config({ + 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4) + num_stages=3, + num_warps=8) ], key=['N', 'K'], + warmup=5, + rep=20, ) -@triton.jit -def _linear_add( - A, - B, - C, - residual_ptr, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - rms_scale_ptr, - linear_scale_ptr, -): +@triton.jit(do_not_specialize=['M']) +def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak, + stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + rms_scale_ptr, linear_scale_ptr, + ACCUMULATOR_DTYPE: tl.constexpr): """Triton-accelerated function used to perform a linear operation (dot product) on input tensors `A` and `B`, with addition of residual. 
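# Sketch of how the quantized GEMM above recovers a floating-point
# result: the low-precision accumulator is rescaled by the per-token
# activation scale (rms_scale) and the per-output-channel weight scale
# (linear_scale).  Integer matmul is used here only as a CPU reference.
import torch


def dequant_matmul_reference(a_q: torch.Tensor,      # (M, K), int8 activations
                             a_scale: torch.Tensor,  # (M, 1)
                             b_q: torch.Tensor,      # (N, K), int8 weights
                             b_scale: torch.Tensor,  # (N, 1)
                             out_dtype=torch.float16) -> torch.Tensor:
    acc = a_q.to(torch.int32) @ b_q.to(torch.int32).T   # (M, N)
    c = acc.to(torch.float32) * a_scale * b_scale.T
    return c.to(out_dtype)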
@@ -183,11 +180,11 @@ def _linear_add( a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0) - accumulator += tl.dot(a, b) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None) + accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk c = accumulator.to(tl.float32) @@ -231,14 +228,11 @@ def matmul_kernel_dynamic_quant(a, assert residual.shape == c_shape assert residual.is_contiguous() c = a.new_empty(c_shape, dtype=output_dtype) - - BLOCK_M = 128 - if M < BLOCK_M: - BLOCK_M = triton.next_power_of_2(M) - BLOCK_M = max(BLOCK_M, 16) + accumulator_dtype = tl.float32 if a.is_floating_point() else tl.int32 def grid(META): - return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), ) + return (triton.cdiv(M, META['BLOCK_M']) * + triton.cdiv(N, META['BLOCK_N']), ) kernel_meta = get_kernel_meta(a) if residual is not None: @@ -255,10 +249,10 @@ def grid(META): b.stride(0), c.stride(-2), c.stride(-1), - BLOCK_M=BLOCK_M, GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, linear_scale_ptr=linear_scale, + ACCUMULATOR_DTYPE=accumulator_dtype, **kernel_meta) else: _linear[grid](a, @@ -273,10 +267,10 @@ def grid(META): b.stride(0), c.stride(-2), c.stride(-1), - BLOCK_M=BLOCK_M, GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, linear_scale_ptr=linear_scale, + ACCUMULATOR_DTYPE=accumulator_dtype, **kernel_meta) if bias is not None: c += bias @@ -286,13 +280,16 @@ def grid(META): @triton.jit def _per_token_quant_int8( - y_ptr, - y_q_ptr, - y_s_ptr, - y_stride, - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK: tl.constexpr, + y_ptr, + y_q_ptr, + y_s_ptr, + y_stride: tl.constexpr, + yq_stride: tl.constexpr, + N, # number of columns in X + eps: tl.constexpr, # epsilon to avoid division by zero + BLOCK: tl.constexpr, + Q_MAX: tl.constexpr, + IS_FLOATING_POINT: tl.constexpr, # True for floating point dtype ): """A Triton-accelerated function to perform per-token quantization on a tensor. @@ -302,7 +299,7 @@ def _per_token_quant_int8( # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) y_ptr += row * y_stride - y_q_ptr += row * y_stride + y_q_ptr += row * yq_stride y_s_ptr += row cols = tl.arange(0, BLOCK) # N <= BLOCK @@ -311,21 +308,26 @@ def _per_token_quant_int8( y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / 127 - y_q = tl_round(y / y_s).to(tl.int8) + y_s = _absmax / Q_MAX + y_q = y / y_s + if not IS_FLOATING_POINT: + y_q = tl_round(y_q).to(tl.int8) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) -def per_token_quant_int8(x, eps): +def per_token_quant_int8(x, eps, quant_dtype=torch.int8): """Function to perform per-token quantization on an input tensor `x`. It converts the tensor values into signed 8-bit integers and returns the quantized tensor along with the scaling factor used for quantization. 
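# Per-token (per-row) symmetric quantization in plain PyTorch -- a sketch
# of the generalized behaviour implemented by the kernel below, where the
# target dtype may be int8 or an fp8 format; rounding only applies to
# integer dtypes.  The torch reference in test_per_token_quant further
# down follows the same recipe.
import torch


def per_token_quant_reference(x: torch.Tensor, eps: float,
                              quant_dtype: torch.dtype = torch.int8):
    info = (torch.finfo(quant_dtype) if quant_dtype.is_floating_point
            else torch.iinfo(quant_dtype))
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / info.max
    q = x / scale
    if not quant_dtype.is_floating_point:
        q = q.round()
    return q.clamp(info.min, info.max).to(quant_dtype), scale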
""" - - x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) + qdtype_info = torch.finfo( + quant_dtype) if quant_dtype.is_floating_point else torch.iinfo( + quant_dtype) + q_max = qdtype_info.max + x_q = torch.empty_like(x, device=x.device, dtype=quant_dtype) M = x.numel() // x.shape[-1] N = x.shape[-1] x_s = torch.empty(x.shape[:-1] + (1, ), @@ -334,94 +336,184 @@ def per_token_quant_int8(x, eps): BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) + + if x.dim() > 2: + x = x.flatten(0, -2) + assert x.stride(-1) == 1 # enqueue kernel kernel_meta = get_kernel_meta(x) - _per_token_quant_int8[(M, )](x, - x_q, - x_s, - x.stride(-2), - N, - eps, - BLOCK=BLOCK, - num_warps=num_warps, - **kernel_meta) + _per_token_quant_int8[(M, )]( + x, + x_q, + x_s, + y_stride=x.stride(-2), + yq_stride=x_q.stride(-2), + N=N, + eps=eps, + BLOCK=BLOCK, + Q_MAX=q_max, + IS_FLOATING_POINT=quant_dtype.is_floating_point, + num_warps=num_warps, + **kernel_meta) return x_q, x_s @triton.jit -def _rms_norm_fwd_fused_dynamic_symmetric( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - Scale, # pointer to the scales of the output activation - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, +def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr): + """compute rms norm.""" + xf = x.to(tl.float32) + + var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS) + out = xf * tl.math.rsqrt(var + eps) + out = (w * out).to(x.dtype) + return out + + +@triton.jit +def rms_norm_quant_kernel( + input, + weight, + output, + out_scale, + input_row_stride: tl.constexpr, + eps: tl.constexpr, + N_COLS: tl.constexpr, + BLOCK_N: tl.constexpr, + Q_MIN: tl.constexpr, + Q_MAX: tl.constexpr, + IS_FLOATING_POINT: tl.constexpr, ): - """A Triton kernel that calculates Root Mean Square (RMS) normalization - with fused dynamic symmetric quantization.""" - row = tl.program_id(0) - Y += row * stride - X += row * stride + """rms norm kernel.""" + prog_id = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) - cols = tl.arange(0, BLOCK_SIZE) - mask = cols < N - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) - _var = x * x - var = tl.sum(_var, axis=0) / N - rstd = tl.math.rsqrt(var + eps) + w = tl.load(weight + offsets, mask=offsets < N_COLS) + + x_ptr = input + prog_id * input_row_stride + x = tl.load(x_ptr + offsets, mask=offsets < N_COLS) + out = _compute_rms_norm(x, w, eps, N_COLS) + + scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX + out_s_ptr = out_scale + prog_id + tl.store(out_s_ptr, scale) + out = out / scale + if not IS_FLOATING_POINT: + out = tl_round(out) + out = tl.clamp(out, Q_MIN, Q_MAX) + out_ptr = output + prog_id * input_row_stride + tl.store(out_ptr + offsets, out, mask=offsets < N_COLS) - w = tl.load(W + cols, mask=mask) - x_hat = x * rstd - y = x_hat * w - scale = tl.max(tl.abs(y)).to(tl.float32) / 127 - tl.store(Scale + row, scale) +@triton.jit +def add_rms_norm_quant_kernel( + input, + weight, + residual, + output, + out_scale, + out_residual, + input_row_stride: tl.constexpr, + residual_row_stride: tl.constexpr, + eps: tl.constexpr, + N_COLS: tl.constexpr, + BLOCK_N: tl.constexpr, + Q_MIN: tl.constexpr, + Q_MAX: tl.constexpr, + IS_FLOATING_POINT: tl.constexpr, +): + """rms norm kernel.""" + prog_id = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) + + w = tl.load(weight + offsets, mask=offsets < 
N_COLS) - y = tl_round(y / scale) - y = tl.minimum(y, 127) - y = tl.maximum(y, -128) - tl.store(Y + cols, y, mask=mask) + x_ptr = input + prog_id * input_row_stride + x = tl.load(x_ptr + offsets, mask=offsets < N_COLS) + res_ptr = residual + prog_id * residual_row_stride + res = tl.load(res_ptr + offsets, mask=offsets < N_COLS) -def rms_norm_dynamic_quant(x, w, eps): + new_x = x + res + out_res_ptr = out_residual + prog_id * residual_row_stride + tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS) + + out = _compute_rms_norm(new_x, w, eps, N_COLS) + + scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX + out_s_ptr = out_scale + prog_id + tl.store(out_s_ptr, scale) + out = out / scale + if not IS_FLOATING_POINT: + out = tl_round(out) + out = tl.clamp(out, Q_MIN, Q_MAX) + out_ptr = output + prog_id * input_row_stride + tl.store(out_ptr + offsets, out, mask=offsets < N_COLS) + + +def rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8): """Performs RMS normalization with dynamic quantization. The function reshapes the input tensor `x`, creates an empty tensor `y` with the same shape as `x`, and calculates RMS normalization on the - reshaped `x` using a Triton kernel `_rms_norm_fwd_fused_dynamic_symmetric`. + reshaped `x` using a Triton kernel `rms_norm_quant_kernel`. """ - - x_arg = x.flatten(0, -2) - y = torch.empty_like(x, dtype=torch.int8) - M, K = x_arg.shape - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K)) - if K > BLOCK_SIZE: - raise RuntimeError( - "This rms norm doesn't support feature dim >= 64KB.") - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + qdtype_info = torch.finfo( + quant_dtype) if quant_dtype.is_floating_point else torch.iinfo( + quant_dtype) + y = torch.empty_like(x, dtype=quant_dtype) scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32) - kernel_meta = get_kernel_meta(x_arg) - _rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg, - y, - w, - scale, - x_arg.stride(0), - K, - eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - **kernel_meta) - return y, scale + + feat_size = w.shape[0] + seq_len = x.numel() // x.size(-1) + input_stride = x.stride(-2) + BLOCK_N = triton.next_power_of_2(feat_size) + grid = (seq_len, ) + + if residual is None: + rms_norm_quant_kernel[grid]( + x, + w, + y, + scale, + input_row_stride=input_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + Q_MIN=qdtype_info.min, + Q_MAX=qdtype_info.max, + IS_FLOATING_POINT=quant_dtype.is_floating_point, + num_warps=4, + num_stages=2) + return y, scale + else: + out_residual = torch.empty_like(x) + res_stride = residual.stride(-2) + add_rms_norm_quant_kernel[grid]( + x, + w, + residual, + y, + scale, + out_residual, + input_row_stride=input_stride, + residual_row_stride=res_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + Q_MIN=qdtype_info.min, + Q_MAX=qdtype_info.max, + IS_FLOATING_POINT=quant_dtype.is_floating_point, + num_warps=4, + num_stages=2) + return y, scale, out_residual def test_rms_and_linear(x, rms_weight, linear_weight, - dtype=torch.float16, + output_dtype=torch.float16, + quant_dtype=torch.int8, eps=1e-5): """Test quantized rms norm and quantized linear layer.""" @@ -434,15 +526,18 @@ def linear_torch(x, b): return F.linear(x, b) linear_weight_quant, linear_scale = per_channel_quant( - linear_weight, 8, torch.int8) + linear_weight, quant_dtype) - rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps) + rms_out, rms_scale = rms_norm_dynamic_quant(x, + rms_weight, + eps, 
+ quant_dtype=quant_dtype) assert rms_out.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1] linear_out = matmul_kernel_dynamic_quant(rms_out, linear_weight_quant, rms_scale, linear_scale, - output_dtype=dtype) + output_dtype=output_dtype) rms_out_torch = rms_norm_torch(x, rms_weight, eps).half() linear_out_torch = linear_torch(rms_out_torch, linear_weight) @@ -456,17 +551,26 @@ def linear_torch(x, b): linear_out_torch.flatten().to(torch.float32))) -def test_per_token_quant(x, eps): +def test_per_token_quant(x, eps, quant_dtype=torch.int8): """Test per-token quantization.""" - def per_token_quant_int8_torch(x, eps): + def per_token_quant_int8_torch(x, eps, quant_dtype): + qdtype_info = torch.finfo( + quant_dtype) if quant_dtype.is_floating_point else torch.iinfo( + quant_dtype) + _absmax = torch.clamp(x.abs().max(dim=-1, keepdim=True)[0], min=eps) - x_s = _absmax / 127 - x_q = torch.clamp((x / x_s).round(), min=-128, max=127) + x_s = _absmax / qdtype_info.max + x_q = x / x_s + if not quant_dtype.is_floating_point: + x_q = x_q.round() + x_q = torch.clamp(x_q, min=qdtype_info.min, max=qdtype_info.max) return x_q, x_s - x_q, x_s = per_token_quant_int8(x, eps) - x_q_torch, x_s_torch = per_token_quant_int8_torch(x, eps) + x_q, x_s = per_token_quant_int8(x, eps, quant_dtype=quant_dtype) + x_q_torch, x_s_torch = per_token_quant_int8_torch(x, + eps, + quant_dtype=quant_dtype) assert x_q.shape == x_q_torch.shape and x_s.shape == x_s_torch.shape cos = torch.nn.CosineSimilarity(0) print( @@ -479,21 +583,11 @@ def per_token_quant_int8_torch(x, eps): x_s_torch.flatten().to(torch.float32))) -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M'], - x_vals=[1, 16, 32, 64, 128, 256] + [512 * i * 2 for i in range(1, 17)], - line_arg='provider', - line_vals=['int8_dynamic_triton_op', 'float_torch'], - line_names=['int8_dynamic_triton_op', 'float_torch'], - styles=[('blue', '-'), ('green', '-'), ('orange', '-'), - ('yellow', '-'), ('yellow', '-')], - ylabel='GB/s', - plot_name='forward', - args={ - 'dtype': torch.float16, - })) -def bench_rms_and_linear(M, dtype, provider, eps=1e-5, device='cuda'): +def bench_rms_and_linear(M: int, + provider: str, + dtype: torch.dtype = torch.float16, + eps: float = 1e-5): + """benchmark rms and linear.""" def rms_norm_torch(x, w, eps): variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) @@ -505,6 +599,7 @@ def linear_torch(x, b): N = 4096 K = 4096 + x_shape = (M, K) rms_w_shape = (x_shape[-1], ) rms_weight = torch.randn(rms_w_shape, @@ -516,14 +611,33 @@ def linear_torch(x, b): dtype=dtype, device='cuda', requires_grad=True) - linear_weight_quant, linear_scale = per_channel_quant( - linear_weight, 8, torch.int8) - alpha = max(x.max().abs(), x.min().abs()) - rms_scale = alpha / 127 + if provider == 'torch_fp16': + rms_out_torch = rms_norm_torch(x, rms_weight, eps).half() - if provider == 'int8_dynamic_triton_op': - rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps) + def y_fwd(): + linear_torch(rms_out_torch, linear_weight) + else: + if provider == 'triton_int8': + quant_dtype = torch.int8 + elif provider == 'triton_fp8_e4m3': + quant_dtype = torch.float8_e4m3fn + elif provider == 'triton_fp8_e5m2': + quant_dtype = torch.float8_e5m2 + + linear_weight_quant, linear_scale = per_channel_quant( + linear_weight, quant_dtype) + + alpha = max(x.max().abs(), x.min().abs()) + if quant_dtype.is_floating_point: + qdtype_info = torch.finfo(quant_dtype) + else: + qdtype_info = torch.iinfo(quant_dtype) + rms_scale = alpha / qdtype_info.max + 
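# Plain-PyTorch reference for the contract of `rms_norm_dynamic_quant`
# (a sketch, not the kernel): RMS-normalize in fp32, apply the weight,
# then per-row symmetric quantization to `quant_dtype`; when a residual
# is given, the pre-norm sum is returned as well.
import torch


def rms_norm_dynamic_quant_reference(x, w, eps, residual=None,
                                     quant_dtype=torch.int8):
    if residual is not None:
        x = x + residual
    var = x.float().pow(2).mean(-1, keepdim=True)
    y = (w * x.float() * torch.rsqrt(var + eps)).to(x.dtype)
    info = (torch.finfo(quant_dtype) if quant_dtype.is_floating_point
            else torch.iinfo(quant_dtype))
    scale = y.abs().amax(-1, keepdim=True).float() / info.max
    q = y / scale
    if not quant_dtype.is_floating_point:
        q = q.round()
    q = q.clamp(info.min, info.max).to(quant_dtype)
    return (q, scale) if residual is None else (q, scale, x)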
rms_out, rms_scale = rms_norm_dynamic_quant(x, + rms_weight, + eps, + quant_dtype=quant_dtype) def y_fwd(): @@ -532,21 +646,22 @@ def y_fwd(): rms_scale, linear_scale, output_dtype=dtype) - elif provider == 'float_torch': - rms_out_torch = rms_norm_torch(x, rms_weight, eps).half() - - def y_fwd(): - linear_torch(rms_out_torch, linear_weight) quantiles = [0.5, 0.2, 0.8] ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) - return ms, max_ms, min_ms + + def perf(ms): + return 2 * M * N * K * 1e-12 / (ms * 1e-3) + + return perf(ms), perf(max_ms), perf(min_ms) if __name__ == '__main__': torch.manual_seed(0) + device_map = torch.cuda.get_device_capability() + is_fp8_supported = device_map[0] >= 9 dtype = torch.float16 # test (bs, seq_len, dim) x (dim, out_dim) x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda') @@ -559,7 +674,16 @@ def y_fwd(): dtype=dtype, device='cuda', requires_grad=True) - test_rms_and_linear(x, rms_weight, linear_weight) + test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8) + if is_fp8_supported: + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e4m3fn) + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e5m2) # test (M, K) x (K, N) x = torch.randn((4, 4096), dtype=dtype, device='cuda') @@ -572,11 +696,45 @@ def y_fwd(): dtype=dtype, device='cuda', requires_grad=True) - test_rms_and_linear(x, rms_weight, linear_weight) + test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8) + if is_fp8_supported: + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e4m3fn) + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e5m2) # test per-token quant x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda') eps = 1e-7 - test_per_token_quant(x, eps) - - bench_rms_and_linear.run(print_data=True) + test_per_token_quant(x, eps, quant_dtype=torch.int8) + if is_fp8_supported: + test_per_token_quant(x, eps, quant_dtype=torch.float8_e4m3fn) + test_per_token_quant(x, eps, quant_dtype=torch.float8_e5m2) + + # benchmark triton kernels + line_vals = ['triton_int8', 'torch_fp16'] + line_names = ['triton_int8', 'torch_fp16'] + + if is_fp8_supported: + line_vals += ['triton_fp8_e4m3', 'triton_fp8_e5m2'] + line_names += ['triton_fp8_e4m3', 'triton_fp8_e5m2'] + config = triton.testing.Benchmark(x_names=['M'], + x_vals=[1, 16, 32, 64, 128, 256] + + [512 * i * 2 for i in range(1, 5)], + line_arg='provider', + line_vals=line_vals, + line_names=line_names, + styles=[('blue', '-'), ('green', '-'), + ('orange', '-'), ('black', '-'), + ('yellow', '-')], + ylabel='TFLOPS', + plot_name='bench-triton', + args={ + 'dtype': torch.float16, + }) + bench_funch = (triton.testing.perf_report(config))(bench_rms_and_linear) + bench_funch.run(print_data=True) diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py index 8f86f0019a..fe82010761 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py +++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py @@ -3,6 +3,7 @@ from .apply_rotary_pos_emb import apply_rotary_pos_emb from .awq_kernels import awq_linear from .fill_kv_cache import fill_kv_cache +from .flash_attention import flash_attention_fwd from .fused_moe import fused_moe from .linear import linear from .moe_gating_topk_softmax import moe_gating_topk_softmax @@ -16,6 +17,7 @@ 'fill_kv_cache', 'fused_moe', 'paged_attention_fwd', + 'flash_attention_fwd', 'linear', 
'moe_gating_topk_softmax', 'multinomial_sampling', diff --git a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py index fb2eee9d41..63564d7ed8 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + import dlinfer.ops as ext_ops from torch import Tensor @@ -9,7 +11,16 @@ def fill_kv_cache( key_caches: Tensor, value_caches: Tensor, kv_start_indices: Tensor, + k_scales_zeros: Sequence[Optional[Tensor]], + v_scales_zeros: Sequence[Optional[Tensor]], + quant_bits: int = 0, ): """fill key/value state to cache for paged attention.""" - return ext_ops.fill_kv_cache(key_states, value_states, key_caches, - value_caches, kv_start_indices) + return ext_ops.fill_kv_cache(key_states, + value_states, + key_caches, + value_caches, + kv_start_indices, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros, + quant_bits=quant_bits) diff --git a/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py new file mode 100644 index 0000000000..1788f947ee --- /dev/null +++ b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dlinfer.ops as ext_ops +from dlinfer.utils.type_annotation import Tensor + + +def flash_attention_fwd( + query_states: Tensor, + key_states: Tensor, + value_states: Tensor, + attn_output: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None, + window_size: int = None, + sm_scale: float = None, + logit_softcapping: float = None, + causal: bool = True, +): + num_q_heads = query_states.shape[1] + num_kv_heads = value_states.shape[1] + return ext_ops.prefill_attention( + query_states, + key_states, + value_states, + q_start_loc, + q_seqlens, + max_q_seqlen, + num_q_heads, + num_kv_heads, + attn_mask=None, + softmax_scale=sm_scale, + attn_output=attn_output, + ) diff --git a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py index 72bab2d720..275ea65261 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py +++ b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py @@ -5,12 +5,13 @@ def fused_moe( hidden_states: Tensor, - top_k: int, - topk_ids: Tensor, - topk_weights: Tensor, gate_up_weights: Tensor, down_weights: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + topk: int, + renormalize: bool, ): - """ascend fused moe.""" - return ext_ops.fused_moe(hidden_states, top_k, topk_ids, topk_weights, - gate_up_weights, down_weights) + """dlinfer fused moe.""" + return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, + topk_weights, topk_ids, topk, renormalize) diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 47bcb0cfff..ded85d476d 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -19,6 +19,9 @@ def prefill_attention( block_size: int, attn_mask: Sequence[Optional[Tensor]], is_unpaged_prefill: Optional[bool], + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], ) -> Tensor: num_q_heads = query_states.shape[1] num_kv_heads = value_states.shape[1] @@ -53,11 +56,25 @@ def prefill_attention( num_kv_heads, attn_mask, attn_output=attn_output, + 
kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) -def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, - max_kv_seq_len, block_offsets, block_size): +def paged_token_attention( + q, + k_cache, + v_cache, + attn_output, + kv_seq_len, + max_kv_seq_len, + block_offsets, + block_size, + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], +): num_q_heads, q_head_dim = q.shape[1:3] num_kv_heads = k_cache.shape[-1] // q_head_dim return ext_ops.paged_decode_attention( @@ -71,6 +88,9 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, num_q_heads, num_kv_heads, attn_output=attn_output, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) @@ -91,6 +111,9 @@ def paged_attention_fwd( block_size: int, attn_mask: Sequence[Optional[Tensor]] = (), is_unpaged_prefill: Optional[bool] = None, + kv_scales: Optional[Tensor] = None, + kv_zeros: Optional[Tensor] = None, + quant_bits: Optional[int] = 0, ): if not is_decoding: return prefill_attention( @@ -108,6 +131,9 @@ def paged_attention_fwd( block_size, attn_mask, is_unpaged_prefill, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) else: return paged_token_attention( @@ -119,4 +145,7 @@ def paged_attention_fwd( max_kv_seq_len, block_offsets, block_size, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index b16a78f1f4..968b71fee1 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -8,6 +8,7 @@ from torch import Tensor from lmdeploy.messages import GenerationConfig, LogitsProcessor +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger from .block import LogicalTokenBlocks @@ -205,10 +206,9 @@ def add_sequence( sampling_param: SamplingParam = None, adapter_name: str = None, return_logits: bool = False, - input_embeddings: List[InputEmbeddings] = None, - mrope_position_ids: Tensor = None, - mrope_position_delta: Tensor = None, - cross_attention_states: Tensor = None) -> 'SchedulerSequence': + multimodals: MultiModalInputs = None, + input_embeddings: List[InputEmbeddings] = None + ) -> 'SchedulerSequence': """Add a new message.""" if isinstance(token_ids, Tensor): token_ids = token_ids.numpy() @@ -228,10 +228,8 @@ def add_sequence( adapter_name=adapter_name, arrive_time=time.time(), history_embeddings=HistoryEmbeddings(input_embeddings), + history_multimodals=HistoryMultiModals(multimodals), return_logits=return_logits, - mrope_position_ids=mrope_position_ids, - mrope_position_delta=mrope_position_delta, - cross_attention_states=cross_attention_states, ) self.sequences[seq.seq_id] = seq if self.seq_manager is not None: @@ -361,6 +359,66 @@ def copy(self): return self.clone() +class HistoryMultiModals: + + def __init__(self, multimodals: MultiModalInputs): + if multimodals is None: + multimodals = dict() + self.multimodals = multimodals + + def get_datas(self, start=0, end=-1): + """get multimodals from prompts position [start, end).""" + outs = dict() + test_range = range(start, end) + for modal_type, modal_datas in self.multimodals.items(): + data = [] + for modal_data in modal_datas: + if (modal_data.start not in test_range + and modal_data.end not in test_range): + continue + data.append(modal_data) + if len(data) > 0: + outs[modal_type] = data + return outs + + def add_inputs(self, input_mms: MultiModalInputs): + """add new inputs.""" + for modal_type, 
vals in input_mms.items(): + if modal_type in self.multimodals: + self.multimodals[modal_type] += vals + else: + self.multimodals[modal_type] = vals + + def empty(self): + if len(self.multimodals) == 0: + return 0 + + return all(len(vals) == 0 for vals in self.multimodals) + + @staticmethod + def update_multimodals(input_mms: MultiModalInputs, prev_len: int): + """update multimodals.""" + for vals in input_mms.values(): + for val in vals: + val.start += prev_len + val.end += prev_len + return input_mms + + def get_encoder_len(self, start=0, end=-1): + """get lens of encoder.""" + test_range = range(start, end) + out_len = 0 + for _, modal_datas in self.multimodals.items(): + for modal_data in modal_datas: + if modal_data.encoder_len is None: + continue + if (modal_data.start not in test_range + and modal_data.end not in test_range): + continue + out_len += modal_data.encoder_len + return out_len + + @dataclass class SchedulerSequence: """Scheduler message.""" @@ -369,12 +427,12 @@ class SchedulerSequence: history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds) history_embeddings: HistoryEmbeddings = field( default_factory=HistoryEmbeddings) + history_multimodals: HistoryMultiModals = field( + default_factory=HistoryMultiModals) num_new_tokens: int = 0 sampling_param: SamplingParam = field(default_factory=SamplingParam) logical_blocks: LogicalTokenBlocks = field( default_factory=LogicalTokenBlocks) - sender_id: int = -1 - req_id: int = -1 adapter_name: str = None arrive_time: float = 0.0 meta: Any = None @@ -382,10 +440,7 @@ class SchedulerSequence: random_offsets: int = 0 _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 - mrope_position_ids: Optional[Tensor] = None - mrope_position_delta: Optional[int] = None - cross_attention_states: Optional[Tensor] = None - history_cross_kv_seqlens: int = 0 + model_meta: Dict[str, Any] = None def __post_init__(self): """post init.""" @@ -394,6 +449,10 @@ def __post_init__(self): self._num_images: int = len(self.history_embeddings) self._num_token_ids: int = len(self.history_cache) + self._num_history_cross: int = 0 + self._num_cross: int = self.history_multimodals.get_encoder_len( + 0, self._num_token_ids) + @property def block_size(self) -> int: """block size.""" @@ -464,6 +523,16 @@ def num_all_ids(self): """num all tokens.""" return self.history_len + self._num_token_ids + @property + def num_cross(self): + """num cross.""" + return self._num_cross + + @property + def num_history_cross(self): + """num history cross.""" + return self._num_history_cross + @property def num_blocks(self): """num blocks.""" @@ -489,22 +558,22 @@ def num_all_tokens(self): def num_all_cross_tokens(self): """num of all cross tokens.""" - if self.cross_attention_states is None: - self.history_cross_kv_seqlens = 0 - else: - self.history_cross_kv_seqlens = self.cross_attention_states.shape[ - -2] - return self.history_cross_kv_seqlens + return self._num_cross + self._num_history_cross + + def get_input_multimodals(self): + """get input multimodals.""" + start = self.num_history_ids + end = self.num_all_ids + return self.history_multimodals.get_datas(start, end) def update_token_ids(self, token_ids: Tensor, + multimodals: MultiModalInputs = None, embeddings: List[InputEmbeddings] = None, - cross_attention_states: List[Tensor] = None): + model_meta: Dict[str, Any] = None): """Update token ids, old token ids will be added to history.""" - # cross attention - if cross_attention_states is not None: - 
self.history_cross_kv_seqlens += cross_attention_states.shape[-2] - self.cross_attention_states = cross_attention_states + old_num_history_ids = self._num_history_ids + self._num_history_ids += self._num_token_ids # update history image nums self._num_history_images += self._num_images @@ -516,6 +585,23 @@ def update_token_ids(self, self._num_images = len(new_embeddings) self.history_embeddings.append(new_embeddings) + # update multimodals + if multimodals is not None: + multimodals = HistoryMultiModals.update_multimodals( + multimodals, self.num_all_ids) + self.history_multimodals.add_inputs(multimodals) + + # cross + self._num_history_cross += self._num_cross + if multimodals is not None: + self._num_cross = self.history_multimodals.get_encoder_len( + old_num_history_ids, self._num_history_ids) + else: + self._num_cross = 0 + + if model_meta is not None: + self.model_meta = model_meta + if isinstance(token_ids, Tensor): token_ids = token_ids.numpy() elif not isinstance(token_ids, np.ndarray): @@ -539,3 +625,12 @@ def set_step(self, step: int): self._num_history_ids = step self._num_token_ids = num_all_ids - step self.num_ignored_history = min(step, self.num_ignored_history) + + self.model_meta = None + + # cross + if self.history_multimodals is not None: + self._num_history_cross = self.history_multimodals.get_encoder_len( + 0, self.num_history_ids) + self._num_cross = self.history_multimodals.get_encoder_len( + self._num_history_ids, num_all_ids) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 0ae5dd7986..e984e39abe 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -4,47 +4,19 @@ from typing import Any, Dict, List, Literal import torch +from torch import distributed as dist from lmdeploy.pytorch.backends import get_backend +from lmdeploy.pytorch.config import ModelConfig +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor -@dataclass -class MRopeModelInputs: - """Multimodal rotary position inputs.""" - position_ids: List[torch.LongTensor] = None - deltas: List[torch.LongTensor] = None - - def get_inputs(self, history_lengths: torch.Tensor, - seq_lengths: torch.Tensor): - mrope_position_ids = [] - for (his_len, seq_len, pos_ids, - delta) in zip(history_lengths, seq_lengths, self.position_ids, - self.deltas): - assert pos_ids.dim() == 2, 'invalid mrope_position_ids' - if his_len + seq_len <= pos_ids.shape[1]: - mrope_position_ids.append(pos_ids[:, - his_len:his_len + seq_len]) - else: - mrope_position_ids.append( - torch.tensor([his_len], device=delta.device).expand(3, -1) - + delta) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) - return mrope_position_ids - - def to_device(self, device: str): - """to device.""" - out_dict = dict() - for f in fields(self): - k = f.name - v = getattr(self, k) - if isinstance(v, torch.Tensor): - v = v.to(device) - elif isinstance(v, list): - v = [x.to(device) for x in v] - out_dict[k] = v - - return MRopeModelInputs(**out_dict) +def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'): + """broadcast tensor.""" + if value.device.type == 'meta': + value = torch.empty_like(value, device=device) + dist.broadcast(value, src) + return value @dataclass @@ -56,6 +28,7 @@ class VisionModelInputs: input_embeddings: List[List[torch.Tensor]] = None input_embedding_ranges: List[torch.LongTensor] = None input_embedding_indexing: torch.BoolTensor = None + input_multimodals: List[MultiModalTensor] = None def to_device(self, device: str): 
"""to device.""" @@ -63,12 +36,54 @@ def to_device(self, device: str): for f in fields(self): k = f.name v = getattr(self, k) + if v is None: + continue if isinstance(v, torch.Tensor): v = v.to(device) - elif k == 'input_embedding_ranges' and v is not None: + elif k == 'input_embedding_ranges': v = [e.to(device) for e in v] - elif k == 'input_embeddings' and v is not None: + elif k == 'input_embeddings': v = [[e.to(device) for e in li] for li in v] + elif k == 'input_multimodals': + new_v = [] + for mm_datas in v: + new_mm_datas = dict() + for modal_type, data in mm_datas.items(): + data = [d.to_device(device) for d in data] + new_mm_datas[modal_type] = data + new_v.append(new_mm_datas) + v = new_v + out_dict[k] = v + + return VisionModelInputs(**out_dict) + + def broadcast(self): + """broadcast inputs. + + Do `dist.broadcast_object_list(inputs.to_device('meta'))` + before broadcast tensors. + """ + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if v is None: + continue + if isinstance(v, torch.Tensor): + v = _broadcast_tensor(v) + elif k == 'input_embedding_ranges': + v = [_broadcast_tensor(e) for e in v] + elif k == 'input_embeddings': + v = [[_broadcast_tensor(e) for e in li] for li in v] + elif k == 'input_multimodals': + new_v = [] + for mm_datas in v: + new_mm_datas = dict() + for modal_type, data in mm_datas.items(): + data = [d.broadcast() for d in data] + new_mm_datas[modal_type] = data + new_v.append(new_mm_datas) + v = new_v out_dict[k] = v return VisionModelInputs(**out_dict) @@ -119,12 +134,12 @@ class ModelInputs: num_ignored_history: torch.LongTensor local_adapter_ids: torch.LongTensor = None vision_inputs: VisionModelInputs = None - mrope_inputs: MRopeModelInputs = None - cross_attention_states: torch.Tensor = None - history_cross_kv_seqlens: torch.LongTensor = None + cross_length: torch.LongTensor = None + history_cross_length: torch.LongTensor = None last_hidden_states: torch.Tensor = None medusa_attn_mask: torch.Tensor = None medusa_position_ids: torch.Tensor = None + model_metas: List[Dict[str, Any]] = None def update(self, input_ids: torch.LongTensor): """update input ids.""" @@ -135,44 +150,88 @@ def update(self, input_ids: torch.LongTensor): self.input_ids = input_ids return self - def split(self, split_size: int, block_size: int): + def split(self, split_size: int): """split inputs.""" assert len( self.seq_length) == 1, ('Can not perform split on batched input.') - assert split_size % block_size == 0, ( - 'split_size should be multi of block_size.') input_ids = self.input_ids if input_ids.numel() < split_size: return self - num_blocks = split_size // block_size - overlap = (self.history_lengths[0] % block_size != 0) + flatten_mms = [] + vision_inputs = self.vision_inputs + if vision_inputs is not None: + if vision_inputs.input_multimodals is not None: + input_mms = vision_inputs.input_multimodals[0] + + flatten_mms = [] + for k, mms in input_mms.items(): + mms = [(k, mm) for mm in mms] + flatten_mms += mms + + flatten_mms = sorted(flatten_mms, key=lambda mm: mm[1].start) + max_seq_len = self.seq_length[0].item() ret = [] - block_start = 0 - for i in range(0, max_seq_len, split_size): - start = i - end = min(max_seq_len, i + split_size) - block_end = block_start + num_blocks - if overlap: - block_end += 1 - - block_offsets = self.block_offsets + start = 0 + history_cross_length = self.history_cross_length + cross_length = None + if history_cross_length is not None: + cross_length = self.history_cross_length.clone() + while start < 
max_seq_len: + vision_inputs = None + if len(flatten_mms) > 0: + mm_start = flatten_mms[0][1].start + mm_end = flatten_mms[0][1].end + if mm_start > self.history_lengths + start: + end = min(mm_start - self.history_lengths, + start + split_size) + else: + input_mms = dict() + key, mm = flatten_mms.pop(0) + input_mms.setdefault(key, []) + input_mms[key].append(mm) + end = start + mm.end - mm.start + while len(flatten_mms) > 0: + next_mm = flatten_mms[0] + next_start = next_mm[1].start + next_end = next_mm[1].end + if next_start < mm_end: + key = next_mm[0] + input_mms.setdefault(key, []) + input_mms[key].append(next_mm[1]) + end += max(0, next_end - mm_end) + flatten_mms.pop(0) + + if cross_length is not None: + encoder_len = next_mm[1].encoder_len + if encoder_len is not None: + cross_length += encoder_len + else: + break + vision_inputs = VisionModelInputs( + input_multimodals=[input_mms], ) + else: + end = min(max_seq_len, start + split_size) + inp = ModelInputs( input_ids=self.input_ids[:, start:end], seq_length=input_ids.new_tensor([end - start]), - block_offsets=block_offsets, + block_offsets=self.block_offsets, history_lengths=self.history_lengths + start, is_decoding=self.is_decoding, num_ignored_history=self.num_ignored_history, local_adapter_ids=self.local_adapter_ids, - vision_inputs=self.vision_inputs, - mrope_inputs=self.mrope_inputs, - cross_attention_states=self.cross_attention_states, + vision_inputs=vision_inputs, + model_metas=self.model_metas, + cross_length=cross_length, + history_cross_length=history_cross_length, ) ret.append(inp) - block_start += num_blocks + history_cross_length = cross_length + + start = end return ret @@ -186,8 +245,24 @@ def to_device(self, device: str): v = v.to(device) elif isinstance(v, VisionModelInputs): v = v.to_device(device) - elif isinstance(v, MRopeModelInputs): - v = v.to_device(device) + out_dict[k] = v + + return ModelInputs(**out_dict) + + def broadcast(self): + """broadcast inputs. + + Do `dist.broadcast_object_list(inputs.to_device('meta'))` + before broadcast tensors. + """ + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor): + v = _broadcast_tensor(v) + elif isinstance(v, VisionModelInputs): + v = v.broadcast() out_dict[k] = v return ModelInputs(**out_dict) @@ -201,6 +276,7 @@ class StepContext: dataclass provide these infos and tools. 
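# Usage sketch of the two-step broadcast described in the docstrings
# above (rank and process-group setup are assumed and elided): rank
# `src` first ships the dataclass skeleton carrying meta tensors via
# `broadcast_object_list`, then every rank calls `.broadcast()` so that
# `_broadcast_tensor` can allocate and fill the real tensors.
import torch.distributed as dist


def broadcast_model_inputs(inputs, rank: int, src: int = 0):
    obj = [inputs.to_device('meta') if rank == src else None]
    dist.broadcast_object_list(obj, src)
    if rank != src:
        inputs = obj[0]
    return inputs.broadcast()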
""" input_ids: torch.LongTensor + model_config: ModelConfig block_offsets: torch.LongTensor position_ids: torch.LongTensor attention_mask: torch.LongTensor @@ -213,15 +289,16 @@ class StepContext: local_adapter_ids: torch.LongTensor = None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None + input_multimodals: List[MultiModalTensor] = None vision_inputs: VisionModelInputs = None - mrope_position_ids: torch.Tensor = None attn_metadata: Any = None - cross_attn_metadata: Any = None - cross_attention_states: torch.Tensor = None + cross_seqlens: torch.LongTensor = None cross_kv_seqlens: torch.LongTensor = None + cross_attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 last_hidden_states: torch.Tensor = None medusa_attn_mask: torch.Tensor = None + model_metas: List[Dict[str, Any]] = None _outputs: Dict = field(default_factory=dict) @@ -229,6 +306,7 @@ class StepContext: def new( cls, inputs: ModelInputs, + model_config: ModelConfig, world_size: int = 1, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, @@ -244,34 +322,38 @@ def new( history_seqlens = inputs.history_lengths device = q_seqlens.device + input_multimodals = None + if inputs.vision_inputs is not None: + input_multimodals = inputs.vision_inputs.input_multimodals + # for vlm input_embeddings, input_embedding_indexing = None, None if (inputs.vision_inputs is not None and inputs.vision_inputs.input_embeddings is not None): input_embeddings, input_embedding_indexing = \ inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) - # for mrope - mrope_position_ids = None - if inputs.mrope_inputs is not None: - mrope_position_ids = inputs.mrope_inputs.get_inputs( - history_seqlens, q_seqlens) # for speculative decoding last_hidden_states = inputs.last_hidden_states # kv_seqlens - cross_attention_states = inputs.cross_attention_states if inputs.is_decoding: attention_mask = torch.ones_like(q_seqlens)[:, None] - position_ids = history_seqlens.unsqueeze(-1) - cross_attention_states = None + position_ids = history_seqlens.unsqueeze(-1).clone() else: - max_q_seqlen = q_seqlens.max().item() + max_q_seqlen = q_seqlens.contiguous().max().item() mask_range = torch.arange(max_q_seqlen, device=device)[None, :] attention_mask = (mask_range < q_seqlens[:, None]).long() position_ids = attention_mask.long().cumsum(-1) - 1 position_ids += history_seqlens.unsqueeze(-1) q_start_loc = q_seqlens.cumsum(0) - q_seqlens + # cross + cross_seqlens = inputs.cross_length + cross_kv_seqlens = None + if inputs.cross_length is not None: + cross_kv_seqlens = (inputs.cross_length + + inputs.history_cross_length) + # position ids 1d position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None] # seq_len + history_length @@ -283,10 +365,12 @@ def new( ret = StepContext( input_ids=inputs.input_ids, + model_config=model_config, block_offsets=inputs.block_offsets, position_ids=position_ids, input_embeddings=input_embeddings, input_embedding_indexing=input_embedding_indexing, + input_multimodals=input_multimodals, attention_mask=attention_mask, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, @@ -296,12 +380,12 @@ def new( world_size=world_size, local_adapter_ids=inputs.local_adapter_ids, vision_inputs=inputs.vision_inputs, - mrope_position_ids=mrope_position_ids, - cross_attention_states=cross_attention_states, last_hidden_states=last_hidden_states, medusa_attn_mask=inputs.medusa_attn_mask, - cross_kv_seqlens=inputs.history_cross_kv_seqlens, kv_quant_policy=kv_quant_policy, + model_metas=inputs.model_metas, + 
cross_seqlens=cross_seqlens, + cross_kv_seqlens=cross_kv_seqlens, ) ret = get_backend().update_step_context(ret) @@ -330,6 +414,7 @@ def __init__(self): @staticmethod def build_context( inputs: ModelInputs, + model_config: ModelConfig, world_size: int = 1, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, @@ -337,6 +422,7 @@ def build_context( """build context.""" return StepContext.new( inputs, + model_config, world_size, kv_caches, kv_quant_policy, diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index 583cd19fe9..38d794f1be 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -228,7 +228,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -245,7 +244,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 8d7a21a0a6..5a83154167 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -1,101 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn +from torch.nn import functional as F from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin LANGUAGE_TOKEN_TYPE = 0 VISION_TOKEN_TYPE = 1 -def get_vision_expert_mask(token_type_ids: torch.LongTensor): - vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) - vision_token_mask[:, :-1] = (token_type_ids[:, :-1] - == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] - == VISION_TOKEN_TYPE) - language_token_mask = ~vision_token_mask - return vision_token_mask, language_token_mask - - -def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: - tmp = x.clone() - # image boi eoi token as LANGUAGE_TOKEN_TYPE - is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) - is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & ( - tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) - is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & ( - tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) - tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE - # final position ids - y = torch.zeros_like(x, dtype=torch.long) - y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ( - (tmp[:, 1:] == VISION_TOKEN_TYPE) & - (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) - y = y.cumsum(dim=-1) - return y - - 
-def _get_cogvlm_position_ids(context): - """get cogvlm position_ids.""" - q_seqlens = context.q_seqlens - history_lengths = context.kv_seqlens - q_seqlens - vision_input_info = context.vision_inputs - position_id_offsets = (vision_input_info.history_image_token_lengths - - vision_input_info.history_image_nums * 3) - lang_ids = None - vis_ids = None - if context.is_decoding: - position_ids = history_lengths - position_id_offsets - else: - if vision_input_info.input_embeddings is not None and len( - vision_input_info.input_embeddings) > 0: - starts = history_lengths - vision_input_info.history_lengths - ends = starts + q_seqlens - token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) - history_position_lengths = (vision_input_info.history_lengths - - position_id_offsets) - position_ids_all = (history_position_lengths[:, None] + - build_position_ids(token_type_ids)) - position_ids = torch.cat([ - pids[s:e] - for (pids, s, e) in zip(position_ids_all, starts, ends) - ]) - vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids) - vision_token_mask = torch.cat([ - masks[s:e] - for (masks, s, e) in zip(vision_token_mask_all, starts, ends) - ]) - mask_indexing = torch.arange(vision_token_mask.shape[-1], - device=vision_token_mask.device) - vis_ids = mask_indexing[vision_token_mask] - lang_ids = mask_indexing[~vision_token_mask] - - else: - position_ids = context.attention_mask.long().cumsum(-1) - 1 - position_ids += (history_lengths - - position_id_offsets).unsqueeze(-1) - device = position_ids.device - position_ids_1d = [ - ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) - ] - position_ids = torch.cat(position_ids_1d).to(device) - - return position_ids, lang_ids, vis_ids - - class SelfAttention(torch.nn.Module): """Parallel self-attention layer abstract class. 
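Note on the broadcast helpers added to model_inputs.py above: the docstrings only hint at the calling convention ("Do `dist.broadcast_object_list(inputs.to_device('meta'))` before broadcast tensors"), so here is a minimal sketch of how a tensor-parallel worker could drive them. It assumes `torch.distributed` is already initialized and that the internal `_broadcast_tensor` helper materializes real storage for meta tensors on non-source ranks; the function and variable names below are illustrative and not part of this diff.

    import torch.distributed as dist

    def distribute_model_inputs(inputs, rank: int):
        """Share a ModelInputs instance from rank 0 with the rest of the TP group."""
        # step 1: broadcast the structure only -- a meta-device copy carries
        # shapes/dtypes and the non-tensor fields without any tensor payload
        objs = [inputs.to_device('meta') if rank == 0 else None]
        dist.broadcast_object_list(objs, src=0)
        if rank != 0:
            inputs = objs[0]
        # step 2: broadcast the tensor payloads; on rank 0 this sends the real
        # tensors, on the other ranks it fills in the meta placeholders
        return inputs.broadcast()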
@@ -112,11 +40,10 @@ def __init__(self, self.projection_size = config.kv_channels * config.num_attention_heads self.num_attention_heads = config.num_attention_heads - self.num_kv_heads = self.num_attention_heads + self.num_kv_heads = config.num_key_value_heads self.head_size = (self.projection_size // config.num_attention_heads) - self.multi_query_attention = config.multi_query_attention - if self.multi_query_attention: - self.num_kv_heads = config.multi_query_group_num + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) self.query_key_value = build_qkv_proj( config.hidden_size, num_q_heads=self.num_attention_heads, @@ -126,7 +53,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # apply rotary self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -338,7 +265,6 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - quantization_config = getattr(config, 'quantization_config', None) self.num_layers = config.num_layers self.post_layer_norm = config.post_layer_norm @@ -353,7 +279,6 @@ def build_layer(layer_number): assert config.rmsnorm self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, - quant_config=quantization_config, dtype=dtype, device=device) @@ -410,6 +335,286 @@ def forward(self, input_ids): return embeddings +class PatchEmbedding(nn.Module): + """vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + dtype=dtype, + device=device) + self.cls_embedding = nn.Parameter( + torch.empty(1, config.hidden_size, dtype=dtype, device=device)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size, + dtype=dtype, + device=device) + + def forward(self, images): + """forward.""" + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + """vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + hidden_size = config.hidden_size + num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim**-0.5 + + # packed qkv + self.query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.dense = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + q, k, v = self.query_key_value.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = 
self.dense(attn_output) + return attn_output + + +class EVA2CLIPMLP(nn.Module): + """vision MLP.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + # gate up + quantization_config = getattr(config, 'quantization_config', None) + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.activation_fn = nn.GELU() + else: + self.activation_fn = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """forward.""" + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + """vision trans layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device) + self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward(self, hidden_states): + """forward.""" + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + """vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + """forward.""" + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + """GLU.""" + + def __init__(self, + config: PretrainedConfig, + in_features: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear_proj = nn.Linear(in_features, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + self.norm1 = nn.LayerNorm(config.hidden_size, + dtype=dtype, + device=device) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, + config.ffn_hidden_size, + bias=False, + dtype=dtype, + device=device) + self.gate_proj = nn.Linear(config.hidden_size, + config.ffn_hidden_size, + bias=False, + dtype=dtype, + device=device) + self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class 
EVA2CLIPModel(nn.Module): + """vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from argparse import Namespace + vision_config = Namespace(**config.vision_config) + + self.patch_embedding = PatchEmbedding(vision_config, + dtype=dtype, + device=device) + self.transformer = EVA2CLIPTransformer(vision_config, + dtype=dtype, + device=device) + self.linear_proj = GLU(config, + in_features=config.hidden_size, + dtype=dtype, + device=device) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2, + dtype=dtype, + device=device) + self.boi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.eoi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images): + """forward.""" + x = self.patch_embedding(images) + x = self.transformer(x) + + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x + + class ChatGLMModel(nn.Module): def __init__(self, @@ -442,19 +647,32 @@ def __init__(self, dtype=dtype, device=device) + self.vision = None + if hasattr(config, 'vision_config'): + self.vision = EVA2CLIPModel(config, dtype=dtype, device=device) + def forward( self, input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, + images: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, ): """forward.""" # token embedding if inputs_embeds is None: + images_features = None + if images is not None: + images_features = self.vision(images) + images_features = images_features.flatten(0, 1)[None] inputs_embeds = self.embedding(input_ids) + if images is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + images_features) hidden_states = inputs_embeds @@ -477,7 +695,8 @@ def get_input_embeddings(self): return self.embedding -class ChatGLMForConditionalGeneration(nn.Module, CudaGraphMixin): +class ChatGLMForConditionalGeneration(nn.Module, DeployModelMixin, + CudaGraphMixin): """rewrote model of LlamaForCausalLM.""" def __init__(self, @@ -491,12 +710,16 @@ def __init__(self, # build Model self.transformer = ChatGLMModel(config, dtype=dtype, device=device) + self.input_processor = ChatGLMInputProcessor(self.config, dtype) + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + images: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): @@ -506,6 +729,8 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) return hidden_states @@ -529,8 +754,23 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - if context.vision_inputs is not None: - position_ids = 
_get_cogvlm_position_ids(context)[0][None] + + images = None + image_mask = None + if context.input_multimodals is not None: + images = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + images = [data for im_data in images for data in im_data] + if len(images) != 0: + image_token_id = images[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + images = torch.stack([data.data for data in images]) + else: + images = None + image_mask = None # process vision embeddings vision_embeddings = context.input_embeddings @@ -548,9 +788,92 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + model_metas = context.model_metas + if not hasattr(self.config, 'vision_config'): + return model_metas + + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_imgs = [[] for _ in model_metas] + else: + input_imgs = [] + for mm in input_multimodals: + if mm is None: + input_imgs.append([]) + else: + input_imgs.append(mm.get('image', [])) + + config = self.config + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + num_pad = vision_token_num - 3 + + batched_num_img_tokens = [] + new_model_metas = [] + for meta, imgs in zip(model_metas, input_imgs): + if meta is None: + num_img_tokens = 0 + else: + num_img_tokens = meta.get('num_img_tokens', 0) + + batched_num_img_tokens.append(num_img_tokens) + + num_img_tokens += num_pad * len(imgs) + new_model_metas.append(dict(num_img_tokens=num_img_tokens)) + + # prepare cogvlm position_ids + q_seqlens = context.q_seqlens + position_ids = context.position_ids + + if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs): + num_img_tokens = torch.tensor(batched_num_img_tokens, + device=position_ids.device) + position_ids -= num_img_tokens[None] + else: + batched_position_ids = position_ids[0].split(q_seqlens) + for pos_ids, num_img_tok, imgs in zip(batched_position_ids, + batched_num_img_tokens, + input_imgs): + pos_ids -= num_img_tok + if len(imgs) == 0: + continue + + seq_len = pos_ids.size(0) + start = pos_ids[0].cpu().item() + new_pos_ids = [] + + imgs = sorted(imgs, key=lambda img: img.start) + for img in imgs: + img_pad_pos = img.start + 1 - num_img_tok + num_pad = img.end - img.start - 2 + new_pos_ids += list(range(start, img_pad_pos)) + new_pos_ids += [img_pad_pos] * num_pad + start = img_pad_pos + 1 + num_img_tok += num_pad + + remain = seq_len - len(new_pos_ids) + new_pos_ids += list(range(start, start + remain)) + + new_pos_ids = pos_ids.new_tensor(new_pos_ids) + pos_ids[:] = new_pos_ids + + position_ids = torch.cat(batched_position_ids)[None] + context.position_ids = position_ids + + return new_model_metas + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm @@ -558,7 +881,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'transformer.vision' in name: + if '.query_key_value' in name: + param = params_dict[name] + q, k, v 
= param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) continue + if 'rotary_pos_emb.inv_freq' in name: continue if ('rotary_pos_emb.cos_cached' in name @@ -581,3 +914,53 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class ChatGLMInputProcessor(BaseModelInputProcessor): + """input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + if hasattr(config, 'vision_config'): + vision_config = config.vision_config + self.image_size = vision_config['image_size'] + self.patch_size = vision_config['patch_size'] + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.vision_token_num = self.num_patches // 4 + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + num_pad = input_mm['image_tokens'] + image_token_id = input_mm.get('image_token_id', 0) + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 6caf10df00..8010e5cead 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -1,20 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from argparse import Namespace from typing import Any, Iterable, List, Optional, Tuple import torch import torch.distributed as dist +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.distributed import get_world_rank +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin class VisionExpertAttention(nn.Module): @@ -28,8 +35,9 @@ def __init__(self, is_cogvlm2 = hasattr(config, 'num_multi_query_heads') quantization_config = getattr(config, 'quantization_config', None) num_heads = config.num_attention_heads - num_key_value_heads = getattr(config, 'num_multi_query_heads', - num_heads) + num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads) + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) self.hidden_size = hidden_size @@ -46,7 +54,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) self.language_expert_query_key_value = build_qkv_proj( hidden_size, num_q_heads=num_heads, @@ -56,7 +64,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -322,6 +330,283 @@ def forward( return outputs +class PatchEmbedding(nn.Module): + """vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + dtype=dtype, + device=device) + self.cls_embedding = nn.Parameter( + torch.empty(1, config.hidden_size, dtype=dtype, device=device)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size, + dtype=dtype, + device=device) + + def forward(self, images): + """forward.""" + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + """vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + hidden_size = config.hidden_size + num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim**-0.5 + + # packed qkv + self.query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + 
quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.dense = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + q, k, v = self.query_key_value.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.dense(attn_output) + return attn_output + + +class EVA2CLIPMLP(nn.Module): + """vision MLP.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + # gate up + quantization_config = getattr(config, 'quantization_config', None) + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.activation_fn = nn.GELU() + else: + self.activation_fn = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """forward.""" + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + """vision trans layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device) + self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward(self, hidden_states): + """forward.""" + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + """vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + """forward.""" + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + """GLU.""" + + def __init__(self, + config: PretrainedConfig, + in_features: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear_proj = nn.Linear(in_features, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + 
self.norm1 = nn.LayerNorm(config.hidden_size, + dtype=dtype, + device=device) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False, + dtype=dtype, + device=device) + self.gate_proj = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False, + dtype=dtype, + device=device) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + """vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + vision_config = Namespace(**config.vision_config) + + self.patch_embedding = PatchEmbedding(vision_config, + dtype=dtype, + device=device) + self.transformer = EVA2CLIPTransformer(vision_config, + dtype=dtype, + device=device) + self.linear_proj = GLU(config, + in_features=vision_config.hidden_size, + dtype=dtype, + device=device) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=vision_config.hidden_size, + kernel_size=2, + stride=2, + dtype=dtype, + device=device) + self.boi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.eoi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + + def forward(self, images): + """forward.""" + x = self.patch_embedding(images) + x = self.transformer(x) + + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + return x + + class CogVLMModel(nn.Module): """model.""" @@ -332,7 +617,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -349,10 +633,12 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) + # vision model + self.vision = EVA2CLIPModel(config, dtype=dtype, device=device) + # build rotary embedding emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads @@ -371,6 +657,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, + images: torch.Tensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, lang_ids: torch.LongTensor = None, vision_ids: torch.LongTensor = None, @@ -379,7 +666,12 @@ def forward( # token embedding if inputs_embeds is None: + if images is not None: + images_features = self.vision(images) + inputs_embeds = self.embed_tokens(input_ids) + if vision_ids is not None: + inputs_embeds[0, vision_ids] = images_features.flatten(0, 1) hidden_states = inputs_embeds @@ -416,85 +708,7 @@ def get_input_embeddings(self): VISION_TOKEN_TYPE = 1 -def get_vision_expert_mask(token_type_ids: torch.LongTensor): - vision_token_mask = 
torch.zeros_like(token_type_ids, dtype=torch.bool) - vision_token_mask[:, :-1] = (token_type_ids[:, :-1] - == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] - == VISION_TOKEN_TYPE) - language_token_mask = ~vision_token_mask - return vision_token_mask, language_token_mask - - -def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: - tmp = x.clone() - # image boi eoi token as LANGUAGE_TOKEN_TYPE - is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) - is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & ( - tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) - is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & ( - tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) - tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE - # final position ids - y = torch.zeros_like(x, dtype=torch.long) - y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ( - (tmp[:, 1:] == VISION_TOKEN_TYPE) & - (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) - y = y.cumsum(dim=-1) - return y - - -def _get_cogvlm_position_ids(context): - """get cogvlm position_ids.""" - q_seqlens = context.q_seqlens - history_lengths = context.kv_seqlens - q_seqlens - vision_input_info = context.vision_inputs - position_id_offsets = (vision_input_info.history_image_token_lengths - - vision_input_info.history_image_nums * 3) - lang_ids = None - vis_ids = None - if context.is_decoding: - position_ids = history_lengths - position_id_offsets - else: - if vision_input_info.input_embeddings is not None and len( - vision_input_info.input_embeddings) > 0: - starts = history_lengths - vision_input_info.history_lengths - ends = starts + q_seqlens - token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) - history_position_lengths = (vision_input_info.history_lengths - - position_id_offsets) - position_ids_all = (history_position_lengths[:, None] + - build_position_ids(token_type_ids)) - position_ids = torch.cat([ - pids[s:e] - for (pids, s, e) in zip(position_ids_all, starts, ends) - ]) - vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids) - vision_token_mask = torch.cat([ - masks[s:e] - for (masks, s, e) in zip(vision_token_mask_all, starts, ends) - ]) - mask_indexing = torch.arange(vision_token_mask.shape[-1], - device=vision_token_mask.device) - vis_ids = mask_indexing[vision_token_mask] - lang_ids = mask_indexing[~vision_token_mask] - - else: - position_ids = context.attention_mask.long().cumsum(-1) - 1 - position_ids += (history_lengths - - position_id_offsets).unsqueeze(-1) - device = position_ids.device - position_ids_1d = [ - ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) - ] - position_ids = torch.cat(position_ids_1d).to(device) - - return position_ids, lang_ids, vis_ids - - -class CogVLMForCausalLM(nn.Module, CudaGraphMixin): +class CogVLMForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin): """ModelForCausalLM.""" packed_modules_mapping = { @@ -512,6 +726,8 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + # preprocessor + self.input_processor = CogVLMInputProcessor(self.config, dtype) # build model self.model = CogVLMModel(config, dtype=dtype, device=device) # build lm_head @@ -527,6 +743,7 @@ def forward( position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + images: torch.Tensor = None, inputs_embeds: torch.Tensor = None, lang_ids: torch.LongTensor = None, vision_ids: torch.LongTensor = None, @@ -538,6 +755,7 @@ def forward( 
position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, inputs_embeds=inputs_embeds, lang_ids=lang_ids, vision_ids=vision_ids, @@ -561,8 +779,36 @@ def prepare_inputs_for_generation( """prepare input.""" # get input_ids, position_ids and attention metadatas input_ids = context.input_ids - position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context) - position_ids = position_ids[None] + + # position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context) + position_ids = context.position_ids + lang_ids = None + vis_ids = None + + # vision inputs + images = None + if context.input_multimodals is not None: + images = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + images = [data for im_data in images for data in im_data] + if len(images) == 0: + images = None + + if images is not None: + image_token_id = images[0].meta['image_token_id'] + vis_mask = input_ids[0] == image_token_id + images = torch.stack([data.data for data in images]) + + # get lang_ids + vis_range = torch.arange(0, + input_ids.size(-1), + device=input_ids.device) + vis_ids = vis_range[vis_mask] + lang_ids = vis_range[~vis_mask] + attn_metadata = context.attn_metadata # process vision embeddings @@ -581,6 +827,7 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, inputs_embeds=inputs_embeds, lang_ids=lang_ids, vision_ids=vis_ids, @@ -597,8 +844,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if 'model.vision' in name: - continue if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name @@ -607,6 +852,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.tie_word_embeddings and 'lm_head.weight' in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: + if '.vision.' 
in name: + continue if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -620,6 +867,136 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') + elif '.query_key_value' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') else: param = params_dict[name] load_weight(param, loaded_weight) + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + model_metas = context.model_metas + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_imgs = [[] for _ in model_metas] + else: + input_imgs = [] + for mm in input_multimodals: + if mm is None: + input_imgs.append([]) + else: + input_imgs.append(mm.get('image', [])) + + config = self.config + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + num_pad = vision_token_num - 3 + + batched_num_img_tokens = [] + new_model_metas = [] + for meta, imgs in zip(model_metas, input_imgs): + if meta is None: + num_img_tokens = 0 + else: + num_img_tokens = meta.get('num_img_tokens', 0) + + batched_num_img_tokens.append(num_img_tokens) + + num_img_tokens += num_pad * len(imgs) + new_model_metas.append(dict(num_img_tokens=num_img_tokens)) + + # prepare cogvlm position_ids + q_seqlens = context.q_seqlens + position_ids = context.position_ids + + if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs): + num_img_tokens = torch.tensor(batched_num_img_tokens, + device=position_ids.device) + position_ids -= num_img_tokens[None] + else: + batched_position_ids = position_ids[0].split(q_seqlens) + for pos_ids, num_img_tok, imgs in zip(batched_position_ids, + batched_num_img_tokens, + input_imgs): + pos_ids -= num_img_tok + if len(imgs) == 0: + continue + + seq_len = pos_ids.size(0) + start = pos_ids[0].cpu().item() + new_pos_ids = [] + + imgs = sorted(imgs, key=lambda img: img.start) + for img in imgs: + img_pad_pos = img.start + 1 - num_img_tok + num_pad = img.end - img.start - 2 + new_pos_ids += list(range(start, img_pad_pos)) + new_pos_ids += [img_pad_pos] * num_pad + start = img_pad_pos + 1 + num_img_tok += num_pad + + remain = seq_len - len(new_pos_ids) + new_pos_ids += list(range(start, start + remain)) + + new_pos_ids = pos_ids.new_tensor(new_pos_ids) + pos_ids[:] = new_pos_ids + + position_ids = torch.cat(batched_position_ids)[None] + context.position_ids = position_ids + + return new_model_metas + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class CogVLMInputProcessor(BaseModelInputProcessor): + """input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + self.vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + + def preprocess_input(self, + input_ids: List[int], + input_multimodals=None, + **kwargs) -> PreprocessInputResult: + """prepare 
multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py index e71ff17fe9..7e61fd317d 100644 --- a/lmdeploy/pytorch/models/dbrx.py +++ b/lmdeploy/pytorch/models/dbrx.py @@ -9,7 +9,7 @@ from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, RopeType, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear -from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin @@ -165,7 +165,7 @@ def __init__(self, act_fn_name = ffn_act_fn.get('name', None) assert act_fn_name == 'silu' - self.mlp = FusedMoE( + self.mlp = build_fused_moe( hidden_size, ffn_hidden_size, moe_num_experts, @@ -522,7 +522,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if '.experts' in name: loaded_weight = loaded_weight.unflatten(0, (num_experts, -1)) if '.w1' in name: - name = name.replace('.w1', '.gate_up_weights') + name = name.replace('.w1', '.gate_up.weight') param = params_dict[name] for exp_id in range(num_experts): weight = loaded_weight[exp_id] @@ -531,7 +531,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_id=exp_id, shard_id='gate') elif '.v1' in name: - name = name.replace('.v1', '.gate_up_weights') + name = name.replace('.v1', '.gate_up.weight') param = params_dict[name] for exp_id in range(num_experts): weight = loaded_weight[exp_id] @@ -540,7 +540,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_id=exp_id, shard_id='up') elif '.w2' in name: - name = name.replace('.w2', '.down_weights') + name = name.replace('.w2', '.down.weight') param = params_dict[name] for exp_id in range(num_experts): weight = loaded_weight[exp_id].t() diff --git a/lmdeploy/pytorch/models/deepseek.py b/lmdeploy/pytorch/models/deepseek.py index 5742baeee5..09c0b74fcc 100644 --- a/lmdeploy/pytorch/models/deepseek.py +++ b/lmdeploy/pytorch/models/deepseek.py @@ -12,7 +12,7 @@ SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) -from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin @@ -135,7 +135,7 @@ def __init__(self, self.softmax_topk = SoftmaxTopK(self.top_k) - self.experts = FusedMoE( + self.experts = build_fused_moe( self.hidden_dim, self.ffn_dim, self.num_experts, @@ -265,12 +265,10 @@ def __init__(self, device=device) # build attention layer norm - self.post_attention_layernorm = RMSNorm( - config.hidden_size, - 
config.rms_norm_eps, - quant_config=quantization_config, - dtype=dtype, - device=device) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + dtype=dtype, + device=device) def forward( self, @@ -315,7 +313,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -332,7 +329,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) @@ -528,14 +524,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts = self.config.n_routed_experts expert_params_mapping = [] for exp_id in range(num_experts): - gate_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.gate_proj.weight', exp_id, - 'gate') - up_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.up_proj.weight', exp_id, 'up') - down_param = ('.experts.down_weights', - f'.experts.{exp_id}.down_proj.weight', exp_id, - 'down') + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', + exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', + exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', + exp_id, 'down') expert_params_mapping += [gate_param, up_param, down_param] params_dict = dict(self.named_parameters()) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 34debae229..b69ae6650d 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from torch import nn from lmdeploy.pytorch.distributed import get_world_rank @@ -13,7 +14,7 @@ from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_rowwise_linear) -from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.nn.rotary_embedding import YarnParameters from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -81,7 +82,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - quantization_config = None + quantization_config = getattr(config, 'quantization_config', None) self.q_lora_rank = config.q_lora_rank self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -90,6 +91,9 @@ def __init__(self, self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) + num_key_value_heads = getattr(config, 'num_key_value_heads', 1) if self.q_lora_rank is None: self.q_proj = build_colwise_linear( @@ -99,6 +103,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=True, + quant_config=quantization_config, ) else: self.q_a_proj = build_colwise_linear( @@ -108,6 +113,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=False, + quant_config=quantization_config, ) self.q_a_layernorm = RMSNorm(config.q_lora_rank, 1e-6, @@ -121,6 +127,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=True, + quant_config=quantization_config, ) self.kv_a_proj_with_mqa = build_colwise_linear( @@ -130,6 +137,7 
@@ def __init__(self, dtype=dtype, device=device, is_tp=False, + quant_config=quantization_config, ) self.kv_a_layernorm = RMSNorm(config.kv_lora_rank, 1e-6, @@ -157,10 +165,9 @@ def __init__(self, self.num_heads, config.kv_lora_rank + self.qk_rope_head_dim, scale=self.softmax_scale, - num_kv_heads=1, + num_kv_heads=num_key_value_heads, v_head_size=config.kv_lora_rank, - replicate_kv=True, - ) + num_replicate_kv_heads=num_replicate_kv_heads) self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, @@ -174,6 +181,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=True, + quant_config=quantization_config, ) def _q_proj(self, hidden_states, num_heads: int, nope_size: int, @@ -270,6 +278,104 @@ def forward( return attn_output +class MoEGate(nn.Module): + """Deepseek Gate.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.renormalize = self.top_k > 1 and self.norm_topk_prob + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim), + dtype=dtype, + device=device)) + if self.topk_method == 'noaux_tc': + self.e_score_correction_bias = nn.Parameter( + torch.empty((self.n_routed_experts, ), + dtype=dtype, + device=device)) + self.softmax_topk = SoftmaxTopK(self.top_k) + + def _compute_scores(self, logits: torch.Tensor): + """compute scores.""" + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1, dtype=torch.float32) + elif self.scoring_func == 'sigmoid': + scores = logits.sigmoid() + else: + raise NotImplementedError('insupportable scoring function ' + f'for MoE gating: {self.scoring_func}') + return scores + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + sequence_length, hidden_dim = hidden_states.shape + router_logits = F.linear(hidden_states, self.weight) + + if self.topk_method == 'greedy': + topk_weight, topk_idx = self.softmax_topk(router_logits) + elif self.topk_method == 'group_limited_greedy': + scores = self._compute_scores(router_logits) + grouped_logits = scores.unflatten(-1, (self.n_group, -1)) + group_scores = (grouped_logits.max(-1).values) + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + group_mask = ~group_mask.bool()[..., None] + grouped_logits = grouped_logits.masked_fill(group_mask, 0.0) + scores = grouped_logits.flatten(1, 2) + topk_weight, topk_idx = self.softmax_topk(scores) + elif self.topk_method == 'noaux_tc': + scores = self._compute_scores(router_logits) + scores_for_choice = scores.view( + sequence_length, -1) + self.e_score_correction_bias[None] + group_scores = (scores_for_choice.view( + sequence_length, self.n_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group] + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = 
torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = (group_mask.unsqueeze(-1).expand( + sequence_length, self.n_group, + self.n_routed_experts // self.n_group).reshape( + sequence_length, -1)) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), + 0.0) # [n, e] + _, topk_idx = torch.topk(tmp_scores, + k=self.top_k, + dim=-1, + sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise RuntimeError(f'Unsupported topk_method: {self.topk_method}') + if not self.renormalize: + topk_weight = topk_weight * self.routed_scaling_factor + return topk_weight, topk_idx + + class DeepseekV2MoE(nn.Module): """Deepseek v2 MoE.""" @@ -278,6 +384,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() + quantization_config = getattr(config, 'quantization_config', None) self.hidden_dim = config.hidden_size self.ffn_dim = config.moe_intermediate_size self.num_experts = config.n_routed_experts @@ -289,18 +396,9 @@ def __init__(self, self.n_group = config.n_group self.topk_group = config.topk_group - self.gate = build_rowwise_linear( - self.hidden_dim, - self.num_experts, - bias=False, - dtype=dtype, - device=device, - is_tp=False, - ) - - self.softmax_topk = SoftmaxTopK(self.top_k) + self.gate = MoEGate(config, dtype=dtype, device=device) - self.experts = FusedMoE( + self.experts = build_fused_moe( self.hidden_dim, self.ffn_dim, self.num_experts, @@ -309,6 +407,7 @@ def __init__(self, dtype=dtype, device=device, all_reduce=False, + quant_config=quantization_config, ) self.shared_experts = None @@ -333,27 +432,8 @@ def forward(self, hidden_states: torch.Tensor): """forward.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states) + topk_weights, topk_ids = self.gate(hidden_states) - if self.topk_method == 'greedy': - topk_weights, topk_ids = self.softmax_topk(router_logits) - elif self.topk_method == 'group_limited_greedy': - grouped_logits = router_logits.unflatten(-1, (self.n_group, -1)) - group_scores = (grouped_logits.max(-1).values) - group_idx = torch.topk(group_scores, - k=self.topk_group, - dim=-1, - sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - group_mask = ~group_mask.bool()[..., None] - grouped_logits = grouped_logits.masked_fill(group_mask, 0.0) - router_logits = grouped_logits.flatten(1, 2) - topk_weights, topk_ids = self.softmax_topk(router_logits) - else: - raise RuntimeError(f'Unsupported topk_method: {self.topk_method}') - if not self.renormalize: - topk_weights = topk_weights * self.routed_scaling_factor out_states = self.experts( hidden_states, topk_weights, @@ -450,12 +530,10 @@ def __init__(self, device=device) # build attention layer norm - self.post_attention_layernorm = RMSNorm( - config.hidden_size, - config.rms_norm_eps, - quant_config=quantization_config, - dtype=dtype, - device=device) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + dtype=dtype, + device=device) def forward( self, @@ -572,7 +650,6 @@ def forward( cos, sin = cos[0], sin[0] rotary_pos_emb = (cos, sin) for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] hidden_states, residual = decoder_layer( hidden_states, @@ -601,6 +678,8 @@ def __init__(self, device: torch.device = None): super().__init__() 
self.config = config + self.quantization_config = getattr(config, 'quantization_config', None) + self.dtype = dtype self.ctx_mgr = ctx_mgr self.model = DeepseekV2Model(config, dtype=dtype, device=device) # build lm_head @@ -609,6 +688,7 @@ def __init__(self, bias=False, dtype=dtype, device=device) + self._load_buffers = dict() def forward( self, @@ -692,40 +772,99 @@ def __update_pe(weight, head_dim: int, pe_dim_offset: int): weight = weight.flatten(0, 1) return weight + def __load_kcvc(name: str, weight: torch.Tensor): + """load kc and vc from weight.""" + config = self.config + v_head_dim = config.v_head_dim + qk_nope_head_dim = config.qk_nope_head_dim + w_kc, w_vc = weight.unflatten( + 0, (-1, qk_nope_head_dim + v_head_dim)).split( + [qk_nope_head_dim, v_head_dim], dim=1) + w_vc = w_vc.transpose(1, 2).contiguous() + kc_param_name = name.replace('.kv_b_proj', '.kc') + param_kc = params_dict[kc_param_name] + load_weight(param_kc, w_kc) + vc_param_name = name.replace('.kv_b_proj', '.vc') + param_vc = params_dict[vc_param_name] + load_weight(param_vc, w_vc) + + def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor, + dtype: torch.dtype): + """dequant weight.""" + dim_w0, dim_w1 = weight.shape + dim_s0, dim_s1 = scale.shape + assert dim_w0 % dim_s0 == 0 + assert dim_w1 % dim_s1 == 0 + group0 = dim_w0 // dim_s0 + group1 = dim_w1 // dim_s1 + weight = weight.reshape(dim_s0, group0, dim_s1, group1) + scale = scale.reshape(dim_s0, 1, dim_s1, 1) + weight = weight.to(scale.dtype) * scale + weight = weight.to(dtype) + weight = weight.reshape(dim_w0, dim_w1) + return weight + + def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor): + """dequant weight.""" + if name.endswith('.weight'): + weight_name = name + scale_name = name.replace('.weight', '.scale') + elif name.endswith('.scale'): + weight_name = name.replace('.scale', '.weight') + scale_name = name + self._load_buffers[name] = loaded_weight + if (weight_name in self._load_buffers + and scale_name in self._load_buffers): + weight = self._load_buffers.pop(weight_name) + scale = self._load_buffers.pop(scale_name) + kc_param_name = weight_name.replace('.kv_b_proj', '.kc') + dtype = params_dict[kc_param_name].dtype + weight = __dequant_weight(weight, scale, dtype) + __load_kcvc(weight_name, weight) + for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping: if mod_name not in name: continue - weight = __update_pe(loaded_weight, head_dim, pe_dim_offset) + if name.endswith('.scale'): + weight = loaded_weight + else: + weight = __update_pe(loaded_weight, head_dim, pe_dim_offset) param = params_dict[name] load_weight(param, weight) break else: if '.kv_b_proj' in name: - config = self.config - v_head_dim = config.v_head_dim - qk_nope_head_dim = config.qk_nope_head_dim - w_kc, w_vc = loaded_weight.unflatten( - 0, (-1, qk_nope_head_dim + v_head_dim)).split( - [qk_nope_head_dim, v_head_dim], dim=1) - w_vc = w_vc.transpose(1, 2).contiguous() - kc_param_name = name.replace('.kv_b_proj', '.kc') - param_kc = params_dict[kc_param_name] - load_weight(param_kc, w_kc) - vc_param_name = name.replace('.kv_b_proj', '.vc') - param_vc = params_dict[vc_param_name] - load_weight(param_vc, w_vc) + quantization_config = self.quantization_config + quant_method = None + if quantization_config is not None: + quant_method = quantization_config.get('quant_method') + + if quant_method == 'fp8': + # update blocked fp8 weight + __load_kcvc_blocked_fp8(name, loaded_weight) + else: + __load_kcvc(name, loaded_weight) else: param = params_dict[name] 
load_weight(param, loaded_weight) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" + + def __skip_nextn(name, nextn_keys): + for nextn_key in nextn_keys: + if nextn_key in name: + return True + return False + stacked_params_mapping = [ # (param_name, shard_name, shard_id) ('.gate_up_proj', '.gate_proj', 0), ('.gate_up_proj', '.up_proj', 1), ] + scale_suffix = '.weight_scale_inv' + config = self.config qk_rope_head_dim = config.qk_rope_head_dim kv_lora_rank = config.kv_lora_rank @@ -739,16 +878,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts = self.config.n_routed_experts expert_params_mapping = [] for exp_id in range(num_experts): - gate_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.gate_proj.weight', exp_id, - 'gate') - up_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.up_proj.weight', exp_id, 'up') - down_param = ('.experts.down_weights', - f'.experts.{exp_id}.down_proj.weight', exp_id, - 'down') + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', + exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', + exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', + exp_id, 'down') expert_params_mapping += [gate_param, up_param, down_param] + num_hidden_layers = self.config.num_hidden_layers + + num_nextn_predict_layers = getattr(self.config, + 'num_nextn_predict_layers', 1) + nextn_keys = [ + f'.layers.{num_hidden_layers+i}' + for i in range(num_nextn_predict_layers) + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'rotary_emb.inv_freq' in name: @@ -756,8 +902,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): continue + if '.layers' in name: + # skip nextn + if __skip_nextn(name, nextn_keys): + continue if self.config.tie_word_embeddings and 'lm_head.weight' in name: continue + if name.endswith(scale_suffix): + name = name[:-len(scale_suffix)] + '.scale' if '.experts' in name: self._load_weight_experts( name, diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index 8f8659dc5e..2d2edb9f49 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -31,34 +31,31 @@ def __init__(self, self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.num_kv_heads = self.num_attention_heads + self.num_kv_heads = getattr(config, 'num_kv_heads', + config.num_attention_heads) + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) self.head_size = (self.hidden_size // config.num_attention_heads) - self.multi_query_attention = config.multi_query - if self.multi_query_attention: - self.num_kv_heads = 1 self.query_key_value = build_qkv_proj( config.hidden_size, num_q_heads=self.num_attention_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, bias=config.bias, - replicate_kv=self.multi_query_attention, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # apply rotary self.apply_rotary_pos_emb = ApplyRotaryEmb() self.rotary = config.rotary # attention - self.attn_fwd = Attention( - self.num_attention_heads, - self.head_size, - num_kv_heads=self.num_kv_heads, - alibi=config.alibi, - ) + self.attn_fwd = Attention(self.num_attention_heads, + self.head_size, + num_kv_heads=self.num_kv_heads, + alibi=config.alibi) 
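Stepping back to the DeepSeek-V2 loader above: __dequant_weight rebuilds a full-precision matrix from a block-quantized weight plus a per-block scale. A minimal standalone sketch of that reshape-and-broadcast, assuming (as the asserts do) that each weight dimension is an exact multiple of the corresponding scale dimension; the int8 tensor below is just a stand-in for the FP8 payload:

import torch

def dequant_blockwise(weight, scale, out_dtype=torch.bfloat16):
    (w0, w1), (s0, s1) = weight.shape, scale.shape
    assert w0 % s0 == 0 and w1 % s1 == 0
    g0, g1 = w0 // s0, w1 // s1
    w = weight.reshape(s0, g0, s1, g1).to(scale.dtype)
    w = w * scale.reshape(s0, 1, s1, 1)       # one scale per (g0 x g1) block
    return w.to(out_dtype).reshape(w0, w1)

w_q = torch.randint(-8, 8, (256, 512), dtype=torch.int8)   # stand-in quantized weight
s = torch.rand(2, 4)                                        # one scale per 128x128 block
print(dequant_blockwise(w_q, s).shape)                      # torch.Size([256, 512])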
# o_proj self.dense = build_rowwise_linear(self.hidden_size, diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index ca36f15651..1f24206b16 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -31,7 +31,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = config.head_dim - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -42,7 +43,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -262,7 +263,6 @@ def __init__(self, self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -279,7 +279,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py index 99c622e4ac..fdee716b4a 100644 --- a/lmdeploy/pytorch/models/internlm.py +++ b/lmdeploy/pytorch/models/internlm.py @@ -28,7 +28,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -39,7 +40,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 6cbc2ccff3..52f51a3ad1 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -28,7 +28,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = hidden_size // num_heads - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.wqkv = build_qkv_proj( hidden_size, @@ -39,6 +40,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, + num_replicate_kv_heads=num_replicate_kv_heads, ) # rotary embedding @@ -219,7 +221,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, @@ -239,7 +240,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) @@ -395,6 +395,32 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, ) + def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], + adapter_id: int): + """load lora weights.""" + + from lmdeploy.pytorch.adapter.adapter import load_lora_weights + + num_heads = self.config.num_attention_heads + num_key_value_heads = self.config.num_key_value_heads + hidden_size = self.config.hidden_size + head_dim = hidden_size // num_heads + 
group_size = num_heads // num_key_value_heads + + def _rearange_wqkv(weights): + for name, loaded_weight in weights: + if 'wqkv.lora_B' in name: + loaded_weight = loaded_weight.unflatten( + 0, (-1, 2 + group_size, head_dim)) + q = loaded_weight[:, :-2].flatten(0, 2) + k = loaded_weight[:, -2].flatten(0, 1) + v = loaded_weight[:, -1].flatten(0, 1) + loaded_weight = torch.cat([q, k, v], dim=0) + yield name, loaded_weight + + weights_iter = _rearange_wqkv(weights) + load_lora_weights(self, weights_iter, adapter_id) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm diff --git a/lmdeploy/pytorch/models/internlm2_ve.py b/lmdeploy/pytorch/models/internlm2_ve.py index b1a2329597..c10faa5f5d 100644 --- a/lmdeploy/pytorch/models/internlm2_ve.py +++ b/lmdeploy/pytorch/models/internlm2_ve.py @@ -105,7 +105,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, @@ -125,7 +124,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 70dd8f2159..5fccd627e5 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -1,17 +1,315 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import LayerNorm, RMSNorm +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin -class InternVLChatModel(nn.Module, CudaGraphMixin): +class InternVisionEmbeddings(nn.Module): + """intern vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), ) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=dtype, + device=device) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.empty(1, + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, 
-1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False).reshape( + 1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, + width) + ], + dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +NORM2FN = { + 'rms_norm': RMSNorm, + 'layer_norm': LayerNorm, +} + + +class InternAttention(nn.Module): + """intern vl attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.qkv = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=config.qkv_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm( + self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device, + tp=True, + align=self.head_dim, + ) + self.k_norm = RMSNorm( + self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device, + tp=True, + align=self.head_dim, + ) + + self.scale = self.head_dim**-0.5 + + # o_proj + self.proj = build_rowwise_linear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True, + tp_align_size=self.head_dim) + + def forward(self, hidden_states): + """forward.""" + + # qkv proj + qkv_states = self.qkv(hidden_states) + q, k, v = self.qkv.split_qkv(qkv_states) + + if self.qk_normalization: + q_shape = q.shape + q = self.q_norm(q.flatten(-2, -1)).view(q_shape) + k = self.k_norm(k.flatten(-2, -1)).view(q_shape) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class InternMLP(nn.Module): + """intern vl mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.act = ACT2FN[config.hidden_act] + + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, 
+ dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + """intern vision encoder layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = getattr(config, 'norm_type', 'rms_norm') + + self.attn = InternAttention(config, dtype=dtype, device=device) + self.mlp = InternMLP(config, dtype=dtype, device=device) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + self.ls1 = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + self.ls2 = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + + def forward( + self, + hidden_states: torch.Tensor, + ): + """forward.""" + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1 + + hidden_states = hidden_states + self.mlp( + self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2 + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """intern vision encoder.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, dtype=dtype, device=device) + for idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + ): + """forward.""" + hidden_states = inputs_embeds + for _, encoder_layer in enumerate(self.layers): + layer_outputs = encoder_layer(hidden_states, ) + hidden_states = layer_outputs + return hidden_states + + +class InternVisionModel(nn.Module): + """intern vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + + self.embeddings = InternVisionEmbeddings(config, + dtype=dtype, + device=device) + self.encoder = InternVisionEncoder(config, dtype=dtype, device=device) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + """forward.""" + assert pixel_values.dim() == 4 + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + last_hidden_state = encoder_outputs + + return last_hidden_state + + +class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin): def __init__(self, config: PretrainedConfig, @@ -21,31 +319,106 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + self.select_layer = config.select_layer + llm_config = config.llm_config + self.llm_arch_name = llm_config.architectures[0] + self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' + + vision_config = config.vision_config + if self.is_mono: + from .internvl_patch import InternVisionPatchModel + self.vision_model = InternVisionPatchModel( + vision_config, + dtype=dtype, + device=device, + ) + else: + self.vision_model = InternVisionModel(vision_config, + dtype=dtype, + device=device) + 
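The InternViT encoder layers above use a pre-norm residual with learned per-channel layer scales (ls1/ls2). A toy equivalent built only from stock torch modules, with an invented embed_dim of 32 purely for illustration (the real model uses the fused qkv/linear builders shown in the diff):

import torch
from torch import nn

class TinyLayerScaleBlock(nn.Module):
    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                 nn.Linear(4 * dim, dim))
        self.ls1 = nn.Parameter(torch.ones(dim))     # per-channel residual scales
        self.ls2 = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        h = self.norm1(x)
        a, _ = self.attn(h, h, h, need_weights=False)
        x = x + a * self.ls1                         # scaled attention residual
        x = x + self.mlp(self.norm2(x)) * self.ls2   # scaled MLP residual
        return x

x = torch.randn(2, 197, 32)                  # [batch, 1 CLS + 14*14 patches, dim]
print(TinyLayerScaleBlock(32)(x).shape)      # torch.Size([2, 197, 32])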
self.language_model = build_model_from_hf_config(llm_config, dtype=dtype, device=device) - self.llm_arch_name = llm_config.architectures[0] + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.llm_config.hidden_size + self.downsample_ratio = config.downsample_ratio + self.mlp1 = nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, + dtype=dtype, + device=device), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + llm_hidden_size, + dtype=dtype, + device=device), nn.GELU(), + nn.Linear(llm_hidden_size, + llm_hidden_size, + dtype=dtype, + device=device)) # for Mono-InternVL - self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' if self.is_mono: assert dtype != torch.float16, ( 'Currently Mono-InternVL does not support FP16 due to' 'numerical instability. Please use BF16 instead.') + self.input_processor = InternVLInputProcessor(self.config, dtype) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> + # N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + """extract vision feature.""" + assert self.select_layer == -1 + vit_embeds = self.vision_model(pixel_values) + if self.is_mono: + if int(vit_embeds.shape[1]**0.5)**2 != vit_embeds.shape[1]: + vit_embeds = vit_embeds[:, 1:, :] + else: + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, vision_embedding_indexing: torch.Tensor = None, text_embedding_indexing: torch.Tensor = None, **kwargs, ): + if inputs_embeds is None and pixel_values is not None: + # extract feature + vit_embeds = self.extract_feature(pixel_values) + lang_embeds = self.language_model.get_input_embeddings()(input_ids) + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) + + inputs_embeds = lang_embeds + if self.is_mono: return self.language_model.forward( input_ids=input_ids, @@ -80,11 +453,38 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - # get inputs from context vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing + vision_embedding_indexing = None + + # vision inputs + pixel_values = None + image_mask = None + if context.input_multimodals is not None: + pixel_values = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + pixel_values = [ + data for im_data in pixel_values for data in im_data + ] + if len(pixel_values) > 0: + image_token_id = 
pixel_values[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat([data.data for data in pixel_values]) + else: + pixel_values = None + image_mask = None + if self.is_mono and pixel_values is not None: + vision_embedding_indexing = torch.arange(input_ids.shape[1], + device=input_ids.device) + vision_embedding_indexing = vision_embedding_indexing[ + image_mask[0]] + + # get inputs from context if vision_embeddings is not None and len(vision_embeddings) > 0: + vision_embedding_indexing = context.input_embedding_indexing if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, @@ -104,6 +504,8 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, vision_embedding_indexing=vision_embedding_indexing, text_embedding_indexing=text_embedding_indexing, @@ -114,18 +516,96 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) + def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], + adapter_id: int): + """load lora weights.""" + + if hasattr(self.language_model, 'load_lora_weights'): + return self.language_model.load_lora_weights(weights, adapter_id) + else: + from lmdeploy.pytorch.adapter.adapter import load_lora_weights + + return load_lora_weights(weights, adapter_id) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" - prefix_length = len('language_model.') + lang_prefix = 'language_model.' + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name.startswith(lang_prefix): + continue + + if 'qkv' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + lang_prefix_length = len(lang_prefix) new_weights = dict() for key, val in weights: - if not key.startswith('language_model.'): + if not key.startswith(lang_prefix): continue - new_key = key[prefix_length:] + new_key = key[lang_prefix_length:] new_weights[new_key] = val self.language_model.load_weights(new_weights.items()) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class InternVLInputProcessor(BaseModelInputProcessor): + """internvl input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + vision_config = config.vision_config + self.image_size = vision_config.image_size + self.patch_size = vision_config.patch_size + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.vision_token_num = self.num_patches // 4 + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = 
input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/internvl_patch.py b/lmdeploy/pytorch/models/internvl_patch.py new file mode 100644 index 0000000000..d13ad2d39b --- /dev/null +++ b/lmdeploy/pytorch/models/internvl_patch.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.configuration_utils import PretrainedConfig + + +class InternVisionEmbeddings(nn.Module): + """mono vision.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), ) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=dtype, + device=device) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.empty(1, + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False) + pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, + width) + ], + dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternVisionPatchModel(nn.Module): + """mono vision.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embeddings = InternVisionEmbeddings(config, + dtype=dtype, + device=device) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + if len(pixel_values.shape) != 4: + raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') + + hidden_states = self.embeddings(pixel_values)[:, 1:] + return hidden_states diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index f38c5ef02b..1a98c02f03 100644 --- 
a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -29,7 +29,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -40,6 +41,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, + num_replicate_kv_heads=num_replicate_kv_heads, ) # rotary embedding @@ -450,22 +452,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class LlavaLlamaForCausalLM(LlamaForCausalLM): - """llava llama for causallm.""" - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - """load weights.""" - - new_weights = dict() - for key, val in weights: - if key.startswith('model.vision_tower'): - continue - if key.startswith('model.mm_projector'): - continue - if key.startswith('model.image_newline'): - continue - new_weights[key] = val - - super().load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index 56cb5ca675..751f7343ec 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -1,17 +1,443 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.llava.configuration_llava import LlavaConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin -class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin): +class LlavaMultiModalProjector(nn.Module): + + def __init__(self, + config: LlavaConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CLIPVisionEmbeddings(nn.Module): + """clip vision embedding.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding 
= nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + dtype=dtype, + device=device, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding( + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device, + ) + self.register_buffer('position_ids', + torch.arange(self.num_positions, + device=device).expand((1, -1)), + persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """This method allows to interpolate the pre-trained position + encodings, to be able to use the model on higher resolution images. + + This method is also adapted to support torch.jit tracing. + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing + # to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing( + ) and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + from transformers.utils import torch_int + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, + sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode='bicubic', + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size + or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f' ({self.image_size}*{self.image_size}).') + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +class CLIPAttention(nn.Module): + """clip attention.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.qkv_proj = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + 
num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + self.scale = self.head_dim**-0.5 + + # o_proj + self.out_proj = build_rowwise_linear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + q, k, v = self.qkv_proj.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if attention_mask is not None and causal_attention_mask is not None: + attn_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attn_mask = causal_attention_mask + else: + attn_mask = attention_mask + + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attn_mask, + scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.out_proj(attn_output) + return attn_output + + +class CLIPMLP(nn.Module): + """clip mlp.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + from transformers.activations import ACT2FN + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + self.fc2 = build_rowwise_linear( + config.intermediate_size, + config.hidden_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + """clip encoder layer.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config, dtype=dtype, device=device) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.mlp = CLIPMLP(config, dtype=dtype, device=device) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + ): + """forward.""" + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Module): + """clip encoder.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() 
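Both vision towers in this diff (InternViT earlier and CLIP here) resize their patch position embeddings with bicubic interpolation when the input resolution differs from pretraining, keeping the CLS position untouched. A standalone sketch of that grid resize with made-up sizes:

import torch
import torch.nn.functional as F

def resize_patch_pos_embed(pos_embed, old_grid, new_h, new_w):
    # pos_embed: [1, old_grid*old_grid, dim] -> [1, new_h*new_w, dim]
    dim = pos_embed.shape[-1]
    grid = pos_embed.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)  # [1, dim, g, g]
    grid = F.interpolate(grid, size=(new_h, new_w), mode='bicubic',
                         align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)

pe = torch.randn(1, 24 * 24, 64)                       # trained on a 24x24 patch grid
print(resize_patch_pos_embed(pe, 24, 32, 32).shape)    # torch.Size([1, 1024, 64])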
+ self.config = config + self.layers = nn.ModuleList([ + CLIPEncoderLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + vision_feature_layer: int = -1, + ): + """forward.""" + hidden_states = inputs_embeds + num_vision_layers = len(self.layers) + vision_feature_layer + 1 + for _, encoder_layer in enumerate(self.layers[:num_vision_layers]): + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask=causal_attention_mask, + ) + + hidden_states = layer_outputs + + return hidden_states + + +class CLIPVisionTransformer(nn.Module): + """clip vision transformer.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config, + dtype=dtype, + device=device) + self.pre_layrnorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.encoder = CLIPEncoder(config, dtype=dtype, device=device) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + vision_feature_layer: int = -1, + ) -> BaseModelOutputWithPooling: + """forward.""" + hidden_states = self.embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + vision_feature_layer=vision_feature_layer) + + last_hidden_state = encoder_outputs + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=None, + attentions=None, + ) + + +class CLIPVisionModel(nn.Module): + """clip vision model.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.vision_model = CLIPVisionTransformer(config, + dtype=dtype, + device=device) + + def forward(self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + vision_feature_layer: int = -1, + **kwargs): + """forward.""" + return self.vision_model( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + vision_feature_layer=vision_feature_layer) + + +def build_vision_model(vision_config, + dtype: torch.dtype = None, + device: torch.device = None): + """build vision model.""" + model_type = vision_config.model_type + + if model_type == 'clip_vision_model': + return CLIPVisionModel(vision_config, dtype, device) + else: + raise NotImplementedError(f'<{model_type}> is not implemented.') + + +class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin, + DeployModelMixin): def __init__(self, config: PretrainedConfig, @@ -22,19 +448,67 @@ def __init__(self, self.config = config self.ctx_mgr = ctx_mgr text_config = config.text_config + + self.vision_tower = build_vision_model(config.vision_config, + dtype=dtype, + device=device) + self.language_model = build_model_from_hf_config(text_config, dtype=dtype, device=device) + self.multi_modal_projector = LlavaMultiModalProjector(config, + dtype=dtype, + device=device) + + self.input_processor = 
LLavaInputProcessor(config, dtype) + + def get_image_features(self, + pixel_values, + vision_feature_layer: int = -1, + vision_feature_select_strategy: str = 'default'): + """get image features.""" + selected_image_feature = self.vision_tower( + pixel_values, vision_feature_layer=vision_feature_layer)[0] + if vision_feature_select_strategy == 'default': + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == 'full': + selected_image_feature = selected_image_feature + else: + raise ValueError( + f'Unexpected select feature strategy: {vision_feature_select_strategy}' # noqa: E501 + ) + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features.flatten(0, 1)[None] + + return image_features + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): + if inputs_embeds is None: + image_features = None + if pixel_values is not None: + vision_feature_layer = self.config.vision_feature_layer + select_strategy = self.config.vision_feature_select_strategy + image_features = self.get_image_features( + pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=select_strategy) + inputs_embeds = self.language_model.get_input_embeddings()( + input_ids) + if pixel_values is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + image_features) + return self.language_model.forward(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, @@ -59,6 +533,27 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata + + # vision inputs + pixel_values = None + image_mask = None + if context.input_multimodals is not None: + pixel_values = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + pixel_values = [ + data for im_data in pixel_values for data in im_data + ] + if len(pixel_values) > 0: + image_token_id = pixel_values[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat([data.data for data in pixel_values]) + else: + pixel_values = None + image_mask = None + # get inputs from context vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -75,18 +570,404 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" - prefix_length = len('language_model.') + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ] + + # vis model + lang_prefix = 'language_model.' 
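The forward paths above splice the projected image features into the language embeddings at the image-token positions via masked_scatter_. A tiny sketch of that mechanism with invented sizes and an assumed image token id of 9:

import torch

hidden = 16
input_ids = torch.tensor([[1, 9, 9, 9, 2, 3]])       # three image-token slots
inputs_embeds = torch.zeros(1, 6, hidden)            # stand-in for text embeddings
image_features = torch.randn(1, 3, hidden)           # one projected feature per slot

image_mask = input_ids == 9                          # [1, 6], True at image tokens
inputs_embeds.masked_scatter_(image_mask[..., None], image_features)

assert torch.equal(inputs_embeds[0, 1:4], image_features[0])   # features land in order
print(inputs_embeds[0, 0].abs().sum().item())                  # 0.0 -- text slots untouched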
+ params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name.startswith(lang_prefix): + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + # language model + prefix_length = len(lang_prefix) new_weights = dict() for key, val in weights: - if not key.startswith('language_model.'): + if not key.startswith(lang_prefix): continue new_key = key[prefix_length:] new_weights[new_key] = val self.language_model.load_weights(new_weights.items()) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class LLavaInputProcessor(BaseModelInputProcessor): + """llava input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + + from transformers.image_processing_utils import select_best_resolution + + if not isinstance(grid_pinpoints, list): + raise TypeError('grid_pinpoints should be a list of tuples or lists') + + if not isinstance(image_size, (list, tuple)): + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """Unpads a PyTorch tensor of a padded and resized image.""" + if not isinstance(original_size, (list, tuple)): + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(round(original_height * scale_factor, 7)) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(round(original_width * scale_factor, 7)) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """Calculate the number of patches after the preprocessing for images of + any resolution.""" + from 
transformers.image_processing_utils import select_best_resolution + if not isinstance(grid_pinpoints, list): + raise TypeError('grid_pinpoints should be a list of tuples or lists') + + if not isinstance(image_size, (list, tuple)): + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + + num_patches = (height // patch_size) * (width // patch_size) + # add the base patch + num_patches += 1 + return num_patches + + +class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration): + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config=config, + ctx_mgr=ctx_mgr, + dtype=dtype, + device=device) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, + dtype=dtype, + device=device)) + self.input_processor = LLavaNextInputProcessor(config, dtype) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) for imsize in image_sizes + ] + if pixel_values.dim() == 5: + # stacked if input is + # (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] + for pix_val, num_patch in zip(pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of + # (num_patches, num_channels, height, width) + raise ValueError(f'pixel_values of shape {pixel_values.shape}, ' + 'expect to be of 4 or 5 dimensions') + + selected_image_feature = self.vision_tower( + pixel_values, vision_feature_layer=vision_feature_layer)[0] + if vision_feature_select_strategy == 'default': + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == 'full': + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + + def pack_image_features(self, + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=None): + + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = (self.config.vision_config.image_size // + self.config.vision_config.patch_size) + + if vision_feature_select_strategy == 'default': + expected_num_patches = height * width + elif vision_feature_select_strategy == 'full': + expected_num_patches = height * width + 1 + if expected_num_patches != base_image_feature.shape[0]: + raise ValueError('The number of patches is ' + 'not consistent with the image size.') + + (num_patch_height, + num_patch_width) = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, + num_patch_width, height, + width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, + 3).contiguous() + image_feature = 
image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, + image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1).to( + image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), + dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat( + (image_feature, image_newline[None].to(image_feature)), + dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + image_features = torch.cat(new_image_features, dim=0) + return image_features + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_sizes: torch.Tensor = None, + image_mask: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + if inputs_embeds is None: + image_features = None + if pixel_values is not None: + vision_feature_layer = self.config.vision_feature_layer + select_strategy = self.config.vision_feature_select_strategy + image_sizes = image_sizes.tolist() + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=select_strategy) + image_features = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=select_strategy, + image_newline=self.image_newline, + ) + image_features = image_features[None] + inputs_embeds = self.language_model.get_input_embeddings()( + input_ids) + if pixel_values is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + image_features) + + return self.language_model.forward(input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + attn_metadata=attn_metadata) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # vision inputs + pixel_values = None + image_sizes = None + image_mask = None + if context.input_multimodals is not None: + img_mms = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + img_mms = [data for im_data in img_mms for data in im_data] + if len(img_mms) > 0: + image_token_id = img_mms[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat( + [data.data.flatten(0, 1) for data in img_mms]) + image_sizes = torch.cat( + [data.meta['image_sizes'] for data in img_mms]) + else: + pixel_values = None + image_sizes = None + + # get inputs from context + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + return dict( + 
input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_sizes=image_sizes, + image_mask=image_mask, + inputs_embeds=inputs_embeds, + ) + + +class LLavaNextInputProcessor(BaseModelInputProcessor): + """llava input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + image_sizes = input_mm['image_sizes'] + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_sizes=image_sizes, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py index bc9d086dc9..28da3bad55 100644 --- a/lmdeploy/pytorch/models/medusa.py +++ b/lmdeploy/pytorch/models/medusa.py @@ -10,6 +10,7 @@ from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin vicuna_7b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ), (0, 2), (0, 0, 1), (0, 3), (3, ), (0, 1, 0), (2, 0), (4, ), @@ -138,7 +139,7 @@ def forward(self, x): return x + self.act(self.linear(x)) -class MedusaModel(nn.Module, CudaGraphMixin): +class MedusaModel(nn.Module, CudaGraphMixin, DeployModelMixin): """The medusa model architecture.""" packed_modules_mapping = { diff --git a/lmdeploy/pytorch/models/minicpmv26.py b/lmdeploy/pytorch/models/minicpmv26.py index 725e97d9d7..9e47c56437 100644 --- a/lmdeploy/pytorch/models/minicpmv26.py +++ b/lmdeploy/pytorch/models/minicpmv26.py @@ -29,7 +29,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -40,7 +41,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -226,7 +227,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -246,7 +246,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index 04af4c8526..962cdb3d2b 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ 
-223,7 +223,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -240,7 +239,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) @@ -420,22 +418,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class LlavaMistralForCausalLM(MistralForCausalLM): - """llava forcausallm.""" - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - """load weights.""" - - new_weights = dict() - for key, val in weights: - if key.startswith('model.vision_tower'): - continue - if key.startswith('model.mm_projector'): - continue - if key.startswith('model.image_newline'): - continue - new_weights[key] = val - - super().load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index d98efee712..be414d7bff 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -8,7 +8,7 @@ from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear -from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin @@ -22,7 +22,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - quantization_config = None + quantization_config = getattr(config, 'quantization_config', None) num_heads = config.num_attention_heads num_key_value_heads = config.num_key_value_heads @@ -112,6 +112,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() + quantization_config = getattr(config, 'quantization_config', None) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -124,11 +125,12 @@ def __init__(self, dtype=dtype, device=device, is_tp=False, + quant_config=None, ) self.softmax_topk = SoftmaxTopK(self.top_k) - self.experts = FusedMoE( + self.experts = build_fused_moe( self.hidden_dim, self.ffn_dim, self.num_experts, @@ -137,6 +139,7 @@ def __init__(self, dtype=dtype, device=device, all_reduce=True, + quant_config=quantization_config, ) def forward(self, hidden_states: torch.Tensor): @@ -166,7 +169,7 @@ def __init__(self, device: torch.device = None): super().__init__() self.layer_idx = layer_idx - quantization_config = None + quantization_config = getattr(config, 'quantization_config', None) # build attention layer self.self_attn = MixtralAttention(config, dtype=dtype, device=device) @@ -182,12 +185,10 @@ def __init__(self, device=device) # build attention layer norm - self.post_attention_layernorm = RMSNorm( - config.hidden_size, - config.rms_norm_eps, - quant_config=quantization_config, - dtype=dtype, - device=device) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + dtype=dtype, + device=device) def forward( self, @@ -376,12 +377,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts = 
self.config.num_local_experts expert_params_mapping = [] for exp_id in range(num_experts): - gate_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.w1.weight', exp_id, 'gate') - up_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.w3.weight', exp_id, 'up') - down_param = ('.experts.down_weights', - f'.experts.{exp_id}.w2.weight', exp_id, 'down') + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id, + 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id, + 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id, + 'down') expert_params_mapping += [gate_param, up_param, down_param] params_dict = dict(self.named_parameters()) diff --git a/lmdeploy/pytorch/models/mllama.py b/lmdeploy/pytorch/models/mllama.py index 2596fe5299..15b3e9732b 100644 --- a/lmdeploy/pytorch/models/mllama.py +++ b/lmdeploy/pytorch/models/mllama.py @@ -3,23 +3,61 @@ import torch from torch import nn +from torch.nn import functional as F from transformers.models.llama import LlamaConfig -from transformers.models.mllama.modeling_mllama import MllamaTextConfig +from transformers.models.mllama.modeling_mllama import (MllamaTextConfig, + MllamaVisionConfig) +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, - SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, RMSNorm, + RopeType, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.nn.rotary_embedding import Llama3Parameters from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .utils.cudagraph import CudaGraphMixin +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2 +from .utils.model import DeployModelMixin MLLAMA_IMAGE_TOKEN_ID = 128256 MLLAMA_IMAGE_TOKEN = '<|image|>' +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, + # max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask * attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + class LlamaAttention(nn.Module): """Rewrite module of LlamaAttention.""" @@ -157,6 +195,7 @@ def __init__(self, self.head_dim, num_kv_heads=self.num_key_value_heads, v_head_size=self.head_dim, + causal=False, ) self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) @@ 
-579,7 +618,542 @@ def get_logits(self, hidden_states: torch.Tensor): return self.lm_head(hidden_states) -class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin): +class MllamaPrecomputedPositionEmbedding(nn.Module): + """vis position embedding.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.config = config + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + + self.gate = nn.Parameter(torch.empty(1, dtype=dtype, device=device)) + + # position embedding + self.embedding = nn.Parameter( + torch.empty(self.num_patches, + self.hidden_size, + dtype=dtype, + device=device)) + + # tile position embedding + self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * + self.num_patches * self.hidden_size, + dtype=dtype, + device=device) + + self._weight_inited = False + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + gate_tanh = self.gate.tanh() + gated_position_embedding = (1 - gate_tanh) * self.embedding + self.gate_tanh = gate_tanh + self.gated_position_embedding = gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + self._weight_inited = True + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """forward.""" + self._init_weight() + + # position embeddings + hidden_state = hidden_state + self.gated_position_embedding + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = (self.gate_tanh * + tile_position_embedding) + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: MllamaVisionConfig, + is_gated: bool = True, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size, + dtype=dtype, + device=device) + if is_gated: + self.gate = nn.Parameter(torch.empty(1, dtype=dtype, + device=device)) + + self._weight_inited = False + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + gate_tanh = self.gate.tanh() + self.gate_tanh = gate_tanh + + self._weight_inited = True + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + self._init_weight() + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate_tanh + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaVisionAttention(nn.Module): + """mllama vision attention.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 
'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + + # packed qkv + self.qkv_proj = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.o_proj = build_rowwise_linear(self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_state.size(0) + qkv_states = self.qkv_proj(hidden_state) + qkv_states = qkv_states.flatten(0, -2) + query, key, value = self.qkv_proj.split_qkv(qkv_states) + + query = query.unflatten(0, (batch_size, -1)) + key = key.unflatten(0, (batch_size, -1)) + value = value.unflatten(0, (batch_size, -1)) + q_seq_len = query.shape[1] + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(query, + key, + value, + attn_mask=attention_mask) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + return output + + +class MllamaVisionMLP(nn.Module): + """mllama vision mlp.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionEncoderLayer(nn.Module): + """vision encoder layer.""" + + def __init__(self, + config: MllamaVisionConfig, + is_gated: bool, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_size = config.hidden_size + self.is_gated = is_gated + self.self_attn = MllamaVisionAttention(config, + dtype=dtype, + device=device) + self.mlp = MllamaVisionMLP(config, dtype=dtype, device=device) + + self.input_layernorm = LayerNorm(self.hidden_size, + eps=config.norm_eps, + dtype=dtype, + device=device) + self.post_attention_layernorm = LayerNorm(self.hidden_size, + eps=config.norm_eps, + dtype=dtype, + device=device) + + if is_gated: + self.gate_attn = nn.Parameter( + torch.empty(1, dtype=dtype, device=device)) + self.gate_ffn = nn.Parameter( + torch.empty(1, dtype=dtype, device=device)) + + self._weight_inited = not is_gated + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + self.gate_attn_tanh = self.gate_attn.tanh() + self.gate_ffn_tanh = self.gate_ffn.tanh() + + self._weight_inited = True + + def forward( + self, + 
hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + self._init_weight() + + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + if self.is_gated: + hidden_state = self.gate_attn_tanh * hidden_state + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + if self.is_gated: + hidden_state = self.gate_ffn_tanh * hidden_state + hidden_state = residual + hidden_state + + outputs = hidden_state + + return outputs + + +class MllamaVisionEncoder(nn.Module): + """vision encoder.""" + + def __init__(self, + config: MllamaVisionConfig, + num_layers=32, + is_gated=False, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, + is_gated, + dtype=dtype, + device=device) for _ in range(num_layers) + ]) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + encoder_states = () + for encoder_layer in self.layers: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + ) + encoder_states = encoder_states + (hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + """vision model.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + + self.config = config + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.intermediate_layers_indices = config.intermediate_layers_indices + self.dtype = dtype + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding='valid', + bias=False, + dtype=dtype, + device=device, + ) + + self.class_embedding = nn.Parameter( + torch.empty(self.hidden_size, dtype=dtype, device=device)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config, + dtype=dtype, + device=device, + ) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501 + config, + is_gated=True, + dtype=dtype, + device=device, + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501 + config, + is_gated=True, + dtype=dtype, + device=device, + ) + + # layer norms + self.layernorm_pre = nn.LayerNorm( + self.hidden_size, + dtype=dtype, + device=device, + ) + self.layernorm_post = nn.LayerNorm( + self.hidden_size, + dtype=dtype, + device=device, + ) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + dtype=dtype, + device=device, + ) + self.global_transformer = MllamaVisionEncoder( + config, + config.num_global_layers, + is_gated=True, + dtype=dtype, + device=device, + ) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + 
class_embedding = self.class_embedding.expand(batch_size, 1,
+                                                      hidden_size)
+        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+        return hidden_state
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        aspect_ratio_ids: torch.Tensor,
+        aspect_ratio_mask: torch.Tensor,
+    ):
+        """forward."""
+        (batch_size, num_concurrent_media, num_tiles, num_channels, height,
+         width) = pixel_values.shape
+
+        pixel_values = pixel_values.reshape(
+            batch_size * num_concurrent_media * num_tiles, num_channels,
+            height, width)
+        aspect_ratio_ids = aspect_ratio_ids.reshape(
+            batch_size * num_concurrent_media, -1)
+
+        # Patch embedding
+        patch_embeds = self.patch_embedding(pixel_values.to(self.dtype))
+        hidden_state = patch_embeds.flatten(2).transpose(1, 2)
+
+        # Tile embeddings
+        _, num_patches, dim = hidden_state.shape
+        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+                                            num_tiles, -1, dim)
+        hidden_state = self.pre_tile_positional_embedding(
+            hidden_state, aspect_ratio_ids)
+
+        # Add cls token
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media * num_tiles, num_patches, dim)
+        hidden_state = self.apply_class_embedding(hidden_state)
+        num_patches += 1
+
+        # Position embeddings
+        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+                                            num_tiles, num_patches, dim)
+        hidden_state = self.gated_positional_embedding(hidden_state,
+                                                       aspect_ratio_ids)
+
+        hidden_state = self.layernorm_pre(hidden_state)
+
+        # Compute the number of tokens to pad
+        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
+        # Compute padding tuple for pad function
+        padding = (
+            0, 0, 0, num_padding_patches
+        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+        # Pad the tensor
+        hidden_state = F.pad(hidden_state, padding, mode='constant', value=0)
+        slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+        # Prepare attention mask
+        attention_mask = aspect_ratio_mask.reshape(
+            batch_size * num_concurrent_media, -1)
+        attention_mask = _prepare_aspect_ratio_attention_mask(
+            aspect_ratio_mask=attention_mask,
+            num_patches=self.num_patches,
+            target_length=hidden_state.shape[2],
+            dtype=self.dtype,
+        )
+
+        # Apply encoder
+        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
+                                         dim)
+        output = self.transformer(
+            hidden_state,
+            attention_mask=attention_mask,
+        )
+        hidden_state = output[0]
+
+        hidden_state = self.layernorm_post(hidden_state)
+
+        # Apply global encoder
+        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+                                            num_tiles,
+                                            num_patches + num_padding_patches,
+                                            dim)
+        hidden_state = self.post_tile_positional_embedding(
+            hidden_state, aspect_ratio_ids)
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles * (num_patches + num_padding_patches), dim)
+        global_output = self.global_transformer(
+            hidden_state,
+            attention_mask=attention_mask,
+        )
+        hidden_state = global_output[0]
+
+        # Remove padding from hidden state
+        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+                                            num_tiles,
+                                            num_patches + num_padding_patches,
+                                            dim)
+        hidden_state = hidden_state[:, :, :slice_index]
+        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
+                                            num_tiles, num_patches, dim)
+
+        # Collect intermediate layer outputs from encoder output
+        all_intermediate_hidden_states = output[1]
+        all_intermediate_hidden_states = [
+            all_intermediate_hidden_states[i]
+            for i in self.intermediate_layers_indices
+        ]
+        
intermediate_hidden_states = torch.stack( + all_intermediate_hidden_states, dim=-1) + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + + return hidden_state + + +class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin, + DeployModelMixin): """rewrote model of MllamaForConditionalGeneration.""" packed_modules_mapping = { @@ -602,16 +1176,32 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + + self.vision_model = MllamaVisionModel( + config.vision_config, + dtype=dtype, + device=device, + ) # build MllamaForCausalLM self.language_model = MllamaForCausalLM(config.text_config, dtype=dtype, device=device) + + self.multi_modal_projector = build_rowwise_linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device, + ) self.dtype = dtype - def flat_encoder_result(self, cross_attention_states: torch.Tensor, - attn_metadata: Any, input_ids: torch.LongTensor): + # preprocessor + self.input_processor = MLlamaInputProcessor(self.config, dtype) + + def flat_encoder_result(self, attn_metadata: Any, + input_ids: torch.LongTensor): # since every state share the same shape - cross_attention_states = torch.cat(cross_attention_states, 0) full_text_row_masked_out_mask = torch.ones( (attn_metadata.q_seqlens.sum(), 1), dtype=torch.bool) start_pos = 0 @@ -621,39 +1211,51 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor, full_text_row_masked_out_mask[start_pos:img_id] = False start_pos += q_seq_len full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( - cross_attention_states.device) + input_ids.device) - return cross_attention_states, full_text_row_masked_out_mask + return full_text_row_masked_out_mask def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], - cross_attention_states: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor = None, + aspect_ratio_ids: torch.Tensor = None, + aspect_ratio_mask: torch.Tensor = None, attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, cross_attn_metadata: Any = None, **kwargs, ): """model forward, return logits.""" + if cross_attn_metadata is None: full_text_row_masked_out_mask = None # FIXME basically, we want to inference # text requests and image requests separately - elif cross_attention_states is None and ( - cross_attn_metadata.kv_seqlens is None - or int(cross_attn_metadata.kv_seqlens.sum()) == 0): + elif pixel_values is None and (cross_attn_metadata.kv_seqlens is None): full_text_row_masked_out_mask = None elif cross_attn_metadata.is_decoding: - cross_attention_states = None - full_text_row_masked_out_mask = torch.ones( - (attn_metadata.q_seqlens.sum(), 1), - dtype=torch.bool, - device=input_ids.device) + full_text_row_masked_out_mask = input_ids.new_ones( + input_ids.size(-1), 1) else: - cross_attention_states, full_text_row_masked_out_mask = \ - self.flat_encoder_result(cross_attention_states, cross_attn_metadata, input_ids) # noqa + 
full_text_row_masked_out_mask = self.flat_encoder_result( + cross_attn_metadata, input_ids) # noqa + + cross_attention_states = None + if pixel_values is not None: + cross_attention_states = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) + _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + hidden_states = self.language_model( input_ids=input_ids, position_ids=position_ids, @@ -670,15 +1272,6 @@ def get_logits(self, hidden_states: torch.Tensor): """compute logits of the model output.""" return self.language_model.get_logits(hidden_states) - def support_cuda_graph( - self, - input_ids: torch.Tensor, - **kwargs, - ): - """support cudagraph.""" - - return False - def get_input_embeddings(self): """get input embeddings.""" return self.language_model.model.get_input_embeddings() @@ -694,14 +1287,35 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - cross_attention_states = context.cross_attention_states - if cross_attention_states is not None: - cross_attention_states = [ - t.to(input_ids.device) for t in cross_attention_states - if t is not None - ] cross_attn_metadata = context.cross_attn_metadata + # cross_attn_metadata is None when inputs without image + if cross_attn_metadata is not None and int( + cross_attn_metadata.kv_seqlens.sum()) == 0: + cross_attn_metadata.kv_seqlens = None + + device = input_ids.device + + # process image input + pixel_values = None + aspect_ratio_ids = None + aspect_ratio_mask = None + if context.input_multimodals is not None: + pixel_values = [] + aspect_ratio_ids = [] + aspect_ratio_mask = [] + batched_image_data = [ + input_mm['image'] for input_mm in context.input_multimodals + ] + for image_data in batched_image_data: + for data in image_data: + pixel_values.append(data.data) + aspect_ratio_ids.append(data.meta['aspect_ratio_ids']) + aspect_ratio_mask.append(data.meta['aspect_ratio_mask']) + pixel_values = torch.cat(pixel_values, dim=0).to(device) + aspect_ratio_ids = torch.cat(aspect_ratio_ids, dim=0).to(device) + aspect_ratio_mask = torch.cat(aspect_ratio_mask, dim=0).to(device) + # process vision embeddings vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -719,7 +1333,9 @@ def prepare_inputs_for_generation( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - cross_attention_states=cross_attention_states, + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, cross_attn_metadata=cross_attn_metadata, ) @@ -742,8 +1358,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): continue - if 'vision_model' in name or 'multi_modal_projector' in name: - continue if self.config.text_config.tie_word_embeddings and 'lm_head.weight' in name: # noqa continue for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -756,3 +1370,161 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + + def support_cuda_graph( + self, + input_ids: torch.Tensor, + attn_metadata: Any, + 
cross_attn_metadata: Any,
+        **kwargs,
+    ):
+        """support cudagraph."""
+
+        if not attn_metadata.is_decoding:
+            return False
+
+        if cross_attn_metadata is None:
+            return False
+
+        if cross_attn_metadata.kv_seqlens is None:
+            return False
+
+        return True
+
+    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+        """make cudagraph buffers from forward inputs."""
+        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta,
+                                                       **kwargs)
+
+        device = graph_meta.device
+        max_batches = graph_meta.max_batchs
+        input_buffers['cross_kv_seqlens'] = torch.zeros(max_batches,
+                                                        dtype=torch.int64,
+                                                        device=device)
+
+        return input_buffers
+
+    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+        """fill cudagraph buffers from forward inputs."""
+        input_buffers = graph_meta.input_buffers
+
+        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta,
+                                                    **kwargs)
+
+        attn_metadata = new_inputs['attn_metadata']
+        cross_attn_metadata = new_inputs['cross_attn_metadata']
+        block_offsets = attn_metadata.block_offsets
+        batch_size, _ = block_offsets.size()
+
+        kv_seqlens = cross_attn_metadata.kv_seqlens
+        if kv_seqlens.data_ptr() != input_buffers['cross_kv_seqlens'].data_ptr(
+        ):
+            input_buffers['cross_kv_seqlens'].zero_()
+        input_buffers['cross_kv_seqlens'][:batch_size] = kv_seqlens
+
+        new_batch_size = next_power_of_2(batch_size)
+        cross_attn_metadata.block_offsets = input_buffers[
+            'block_offsets'][:new_batch_size]
+        cross_attn_metadata.q_start_loc = input_buffers[
+            'q_start_loc'][:new_batch_size]
+        cross_attn_metadata.q_seqlens = input_buffers[
+            'q_seqlens'][:new_batch_size]
+        cross_attn_metadata.kv_seqlens = input_buffers[
+            'cross_kv_seqlens'][:new_batch_size]
+
+        new_inputs['cross_attn_metadata'] = cross_attn_metadata
+        return new_inputs
+
+    def update_model_metas(self,
+                           past_key_values: List[List[torch.Tensor]],
+                           inputs_embeds: Optional[torch.Tensor] = None,
+                           context: StepContext = None):
+        """update model meta."""
+        model_metas = context.model_metas
+        if model_metas is None:
+            batch_size = context.q_seqlens.size(0)
+            model_metas = [dict(cross_kv_len=0) for _ in range(batch_size)]
+
+        if context.is_decoding:
+            return model_metas
+
+        vision_inputs = context.vision_inputs
+        if vision_inputs is None:
+            return model_metas
+
+        input_mms = vision_inputs.input_multimodals
+        if input_mms is None:
+            return model_metas
+
+        config = self.config.vision_config
+        image_size = config.image_size
+        patch_size = config.patch_size
+        wh = image_size // patch_size
+        img_kv_len = wh * wh + 1
+        img_kv_len = img_kv_len * 4
+
+        new_model_metas = []
+        for idx, input_mm in enumerate(input_mms):
+            if input_mm is None:
+                new_model_metas.append(model_metas[idx])
+                continue
+            images = input_mm['image']
+            num_img = len(images)
+
+            cross_kv_len = 0
+            if model_metas[idx] is not None:
+                cross_kv_len = model_metas[idx].get('cross_kv_len',
+                                                    cross_kv_len)
+            cross_kv_len += img_kv_len * num_img
+            new_model_metas.append(dict(cross_kv_len=cross_kv_len))
+
+        return new_model_metas
+
+    def get_input_processor(self) -> BaseModelInputProcessor:
+        """get input processor."""
+        return self.input_processor
+
+
+class MLlamaInputProcessor(BaseModelInputProcessor):
+    """mllama input processor."""
+
+    def __init__(self, config: LlamaConfig, dtype: torch.dtype) -> None:
+        self.config = config
+        self.dtype = dtype
+
+        vision_config = self.config.vision_config
+        image_size = vision_config.image_size
+        patch_size = vision_config.patch_size
+        wh = image_size // patch_size
+        encoder_len = wh * wh + 1
+        encoder_len = 
encoder_len * 4 + self.encoder_len = encoder_len + + def preprocess_input(self, input_ids, input_multimodals, **kwargs): + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + aspect_ratio_ids = input_mm['aspect_ratio_ids'] + aspect_ratio_mask = input_mm['aspect_ratio_mask'] + offset = input_mm['offset'] + + if pixel_values.dtype != self.dtype: + pixel_values = pixel_values.to(self.dtype) + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + 1, + encoder_len=self.encoder_len, + meta=dict(aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index af1b23cfc0..b47ff77b3a 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -82,17 +82,19 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM' }) +# deepseek-v3 +MODULE_MAP.update({ + 'DeepseekV3ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM' +}) + # llava MODULE_MAP.update( { - 'LlavaLlamaForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM', - 'LlavaMistralForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM', 'LlavaForConditionalGeneration': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501 'LlavaNextForConditionalGeneration': # noqa: E501 - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration' + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaNextForConditionalGeneration' # noqa: E501 }) # qwen @@ -158,7 +160,7 @@ # phi3 vision MODULE_MAP.update({ 'Phi3VForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3VForCausalLM', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_v.Phi3VForCausalLM', }) # phi-3.5-moe diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 9da1b9f4ea..9604b19af5 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -8,6 +8,7 @@ import torch from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import load_state_dict from lmdeploy.utils import get_logger @@ -250,6 +251,10 @@ def add_adapters(model: torch.nn.Module, ranks, scalings = get_ranks_and_scalings(target_name, adapter_cfgs, device=device) + # split in case target_name has '.' 
like 'attention.wo' + # which cannot be used as name of a module + # and it's not aligned with key in model.packed_modules_mapping + target_name = target_name.split('.')[-1] found_mods, pack_idx = find_all_target(model, target_name) sum_rank = ranks.sum().item() @@ -295,7 +300,9 @@ def add_adapters(model: torch.nn.Module, for name, path in adapters.items(): adapter_id = adapter_id_map[name] checkpoint_path = f'{path}/adapter_model.bin' - state_dict = torch.load(checkpoint_path, map_location=device) + if not osp.exists(checkpoint_path): + checkpoint_path = f'{path}/adapter_model.safetensors' + state_dict = load_state_dict(checkpoint_path, map_location=device) if hasattr(model, 'load_lora_weights'): model.load_lora_weights(state_dict.items(), adapter_id=adapter_id) diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index f9477fdab8..988fee11e5 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -226,7 +226,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -243,7 +242,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) @@ -435,7 +433,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class Phi3VForCausalLM(Phi3ForCausalLM): - ... diff --git a/lmdeploy/pytorch/models/phi3_moe.py b/lmdeploy/pytorch/models/phi3_moe.py index 080f5e996c..7d0572513a 100644 --- a/lmdeploy/pytorch/models/phi3_moe.py +++ b/lmdeploy/pytorch/models/phi3_moe.py @@ -7,7 +7,7 @@ from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RopeType from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear -from lmdeploy.pytorch.nn.moe import FusedMoE +from lmdeploy.pytorch.nn.moe import build_fused_moe from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, build_rotary_embedding) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -180,7 +180,7 @@ def __init__(self, is_tp=False, ) - self.experts = FusedMoE( + self.experts = build_fused_moe( self.hidden_dim, self.ffn_dim, self.num_experts, @@ -448,12 +448,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts = self.config.num_local_experts expert_params_mapping = [] for exp_id in range(num_experts): - gate_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.w1.weight', exp_id, 'gate') - up_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.w3.weight', exp_id, 'up') - down_param = ('.experts.down_weights', - f'.experts.{exp_id}.w2.weight', exp_id, 'down') + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id, + 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id, + 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id, + 'down') expert_params_mapping += [gate_param, up_param, down_param] params_dict = dict(self.named_parameters()) diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py new file mode 100644 index 0000000000..c4bf72c767 --- /dev/null +++ b/lmdeploy/pytorch/models/phi3_v.py @@ -0,0 +1,476 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig + +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn.linear import build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .phi3 import Phi3ForCausalLM, Phi3Model +from .utils.model import DeployModelMixin + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0, + dropout=0.0, + hidden_act='quick_gelu', + hidden_size=1024, + image_size=336, + initializer_factor=1.0, + initializer_range=0.02, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768) + + +class Phi3ImageEmbedding(nn.Module): + """image embedding.""" + + def __init__(self, + config: PretrainedConfig, + wte=None, + dtype: torch.dtype = None, + device: torch.device = None, + **kwargs): + super().__init__() + self.config = config + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + + self.wte = wte + + if (isinstance(config.img_processor, dict) and + config.img_processor.get('name', None) == 'clip_vision_model'): + assert 'model_name' in config.img_processor, ( + 'model_name must be provided for CLIPVisionModel') + assert 'image_dim_out' in config.img_processor, ( + 'image_dim_out must be provided for CLIPVisionModel') + assert 'num_img_tokens' in config.img_processor, ( + 'num_img_tokens must be provided for CLIPVisionModel') + assert config.img_processor[ + 'model_name'] == 'openai/clip-vit-large-patch14-336' + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config).to(device).to( + dtype) + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + else: + raise NotImplementedError( + f'img_processor = {config.img_processor}, not implemented') + + self.image_dim_out = image_dim_out + self.img_sizes = None + + self.use_hd_transform = kwargs.get('use_hd_transform', False) + self.with_learnable_separator = kwargs.get('with_learnable_separator', + False) + self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert (self.use_hd_transform == self.with_learnable_separator), ( + 'use_hd_transform and with_learnable_separator ' + 'should have same value') + if self.with_learnable_separator: + assert self.use_hd_transform, ( + 'learnable separator is only for hd transform') + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter( + torch.empty([1, 1, self.image_dim_out * 4], + dtype=dtype, + device=device)) + self.sub_GN = nn.Parameter( + torch.empty([1, 1, 1, self.image_dim_out * 4], + dtype=dtype, + device=device)) + + projection_cls = kwargs.get('projection_cls', 'linear') + if projection_cls == 'linear': + self.img_projection = nn.Linear(image_dim_out, + hidden_size, + dtype=dtype, + device=device) + elif projection_cls == 'mlp' and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out * 4, + dim_projection, + dtype=dtype, + device=device) + ] + for _ 
in range(1, depth): + layers.extend([ + nn.GELU(), + nn.Linear(dim_projection, + dim_projection, + dtype=dtype, + device=device) + ]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == 'mlp': + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out, + dim_projection, + dtype=dtype, + device=device) + ] + for _ in range(1, depth): + layers.extend([ + nn.GELU(), + nn.Linear(dim_projection, + dim_projection, + dtype=dtype, + device=device) + ]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError( + f'projection_cls = {projection_cls}, not implemented') + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get( + 'type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + def get_img_features(self, + img_embeds: torch.FloatTensor) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, + output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == 'patch': + patch_feature = img_feature[:, 1:] + return patch_feature + + if TYPE_FEATURE == 'cls_patch': + return img_feature + + raise NotImplementedError + + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + image_sizes=None, + image_mask: torch.Tensor = None, + ) -> torch.FloatTensor: + """forward.""" + + target_device = pixel_values.device + target_dtype = pixel_values.dtype + + img_embeds = pixel_values + img_sizes = image_sizes + img_sizes = img_sizes.cpu() + + if self.use_hd_transform and img_sizes is not None and len(img_sizes): + assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501 + # img_embeds: (num_images, max_num_crops, 3, H, W) + # img_sizes: (num_images, 2).view(1, -1) + + bs = img_embeds.shape[0] + # Nx(HW)xC + img_features = self.get_img_features(img_embeds.flatten(0, 1)) + base_feat_height = base_feat_width = int( + img_features.shape[1]**0.5) + + assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501 + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view( + bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + h = h // 336 + w = w // 336 + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature.reshape( + 1, H // 2, 2, H // 2, 2, + C).permute(0, 1, 3, 2, 4, + 5).reshape(1, H // 2, H // 2, 4 * C) + temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], + dim=2).reshape(1, -1, 4 * C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) + # ->(num_crops, 12, 12, 2, 2, 1024) + # -> (num_crops, 12*12, 4*1024) + sub_img = (sub_img.reshape(B_, H // 2, 2, H // 2, 
2, + C).permute(0, 1, 3, 2, 4, 5)) + sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( + 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) + temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) + sub_img = torch.cat([sub_img, temp_sub_GN], + dim=2).reshape(1, -1, 4 * C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError( + f'hd_transform_order = {self.hd_transform_order}' + ) # noqa E501 + + temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + assert temp_len == output_imgs[-1].shape[ + 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501 + output_len.append(temp_len) + + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection( + _output_img.to(target_device).to(target_dtype)) + img_feature_proj = img_feature_proj.flatten(0, 1) + img_set_tensor.append(img_feature_proj) + img_set_tensor = torch.cat(img_set_tensor)[None] + elif img_embeds.ndim == 4: + tt = (self.get_img_features(img_embeds).to(target_device).to( + target_dtype).reshape(-1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. + elif img_embeds.ndim == 3: + tt = (img_embeds.to(target_device).to(target_dtype).view( + -1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. + else: + raise NotImplementedError + + hidden_states = self.wte(input_ids) + + hidden_states.masked_scatter_(image_mask[..., None], img_set_tensor) + + return hidden_states + + +class Phi3VModel(Phi3Model): + """phi3v model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config=config, dtype=dtype, device=device) + + self.vision_embed_tokens = None + if isinstance(config.embd_layer, dict): + # vision embedding layer + embedding_config = { + 'embedding_cls': config.embd_layer['embedding_cls'], + **config.embd_layer + } + self.vision_embed_tokens = Phi3ImageEmbedding( + config, + wte=self.embed_tokens, + dtype=dtype, + device=device, + **embedding_config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_mask: torch.Tensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + if inputs_embeds is None and pixel_values is not None: + inputs_embeds = self.vision_embed_tokens( + input_ids, + pixel_values, + image_sizes, + image_mask, + ) + + return super().forward( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + +class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin): + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config, ctx_mgr, dtype=dtype, device=device) + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Phi3VModel(config, dtype=dtype, device=device) + # build lm_head + 
self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + self.input_processor = Phi3VInputProcessor(config, dtype) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_sizes: torch.Tensor = None, + image_mask: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """forward.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_sizes=image_sizes, + image_mask=image_mask, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + output = super().prepare_inputs_for_generation( + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + context=context) + + # vision inputs + pixel_values = None + if context.input_multimodals is not None: + input_mms = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + input_mms = [data for im_data in input_mms for data in im_data] + if len(input_mms) > 0: + pixel_values = torch.cat([data.data for data in input_mms]) + image_sizes = torch.cat( + [data.meta['image_sizes'] for data in input_mms]) + image_token_id = input_mms[0].meta['image_token_id'] + image_mask = output['input_ids'] == image_token_id + output['pixel_values'] = pixel_values + output['image_sizes'] = image_sizes + output['image_mask'] = image_mask + + return output + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + super().load_weights(weights) + + vis_prefix = 'vision_embed_tokens.' 
+ params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if not (vis_prefix in name): + continue + param = params_dict[name] + load_weight(param, loaded_weight) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class Phi3VInputProcessor(BaseModelInputProcessor): + """Phi3V input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + image_sizes = input_mm['image_sizes'] + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_sizes=image_sizes, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/q_modules.py b/lmdeploy/pytorch/models/q_modules.py index 001fab7a60..8379bb18c9 100644 --- a/lmdeploy/pytorch/models/q_modules.py +++ b/lmdeploy/pytorch/models/q_modules.py @@ -34,13 +34,17 @@ class QRMSNorm(nn.Module): """It performs traditional RMS normalization and then quantizes the output to 8-bit integers.""" - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-6, quant_dtype=torch.int8): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self.quant_dtype = quant_dtype @classmethod - def from_float(cls, mod: nn.Module, initialization: bool = True): + def from_float(cls, + mod: nn.Module, + initialization: bool = True, + quant_dtype=torch.int8): """Class method to create a QRMSNorm instance from a floating-point module. @@ -49,7 +53,7 @@ def from_float(cls, mod: nn.Module, initialization: bool = True): """ hidden_size = mod.weight.shape[0] eps = mod.variance_epsilon - q_mod = cls(hidden_size, eps) + q_mod = cls(hidden_size, eps, quant_dtype=quant_dtype) if initialization: q_mod.weight = nn.Parameter(mod.weight.detach()) return q_mod @@ -62,7 +66,10 @@ def forward(self, hidden_states): with its scale factor. 
""" hidden_states_quant, rms_scale = rms_norm_dynamic_quant( - hidden_states, self.weight, self.variance_epsilon) + hidden_states, + self.weight, + self.variance_epsilon, + quant_dtype=self.quant_dtype) return QTensor(hidden_states_quant, rms_scale) @@ -83,16 +90,18 @@ def __init__(self, out_features: int, bias: bool = True, device=None, - dtype=None) -> None: + dtype=None, + quant_dtype=torch.int8) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.in_features = in_features self.out_features = out_features + self.quant_dtype = quant_dtype self.register_buffer( 'weight', torch.empty((out_features, in_features), device=device, - dtype=torch.int8)) + dtype=quant_dtype)) self.register_buffer( 'scale', torch.empty((out_features, 1), device=device, dtype=torch.float32)) @@ -103,7 +112,10 @@ def __init__(self, self.register_parameter('bias', None) @classmethod - def from_float(cls, mod: nn.Module, initialization: bool = True): + def from_float(cls, + mod: nn.Module, + initialization: bool = True, + quant_dtype=torch.int8): """Class method to create a QLinear instance from a floating-point module. @@ -114,11 +126,12 @@ def from_float(cls, mod: nn.Module, initialization: bool = True): mod.out_features, mod.bias is not None, device=mod.weight.device, - dtype=mod.weight.dtype) + dtype=mod.weight.dtype, + quant_dtype=quant_dtype) if initialization: - weight_quant, scale = per_channel_quant(mod.weight.detach(), 8, - torch.int8) + weight_quant, scale = per_channel_quant(mod.weight.detach(), + quant_dtype) q_mod.weight.data = weight_quant q_mod.scale = scale @@ -137,7 +150,8 @@ def forward(self, input): """ if isinstance(input, torch.Tensor): - input_quant, input_scale = per_token_quant_int8(input, 1e-7) + input_quant, input_scale = per_token_quant_int8( + input, 1e-7, quant_dtype=self.quant_dtype) else: assert isinstance(input, QTensor) input_quant, input_scale = input.tensor, input.scale diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py index bf856461a3..20e184bdf8 100644 --- a/lmdeploy/pytorch/models/qwen.py +++ b/lmdeploy/pytorch/models/qwen.py @@ -229,7 +229,6 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - quantization_config = getattr(config, 'quantization_config', None) self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size self.wte = nn.Embedding(self.vocab_size, @@ -263,7 +262,6 @@ def __init__(self, self.ln_f = RMSNorm(self.embed_dim, eps=config.layer_norm_epsilon, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 82be75e167..a26aa22d5a 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -29,7 +29,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -40,7 +41,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -224,7 +225,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) 
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -241,7 +241,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py index 1aff14483a..de990592d5 100644 --- a/lmdeploy/pytorch/models/qwen2_moe.py +++ b/lmdeploy/pytorch/models/qwen2_moe.py @@ -13,7 +13,7 @@ SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) -from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin @@ -185,7 +185,7 @@ def __init__(self, self.softmax_topk = SoftmaxTopK(self.top_k) - self.experts = FusedMoE( + self.experts = build_fused_moe( self.hidden_dim, self.ffn_dim, self.num_experts, @@ -280,12 +280,10 @@ def __init__(self, device=device) # build attention layer norm - self.post_attention_layernorm = RMSNorm( - config.hidden_size, - config.rms_norm_eps, - quant_config=quantization_config, - dtype=dtype, - device=device) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + dtype=dtype, + device=device) def forward( self, @@ -330,7 +328,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -347,7 +344,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) @@ -531,14 +527,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts = self.config.num_experts expert_params_mapping = [] for exp_id in range(num_experts): - gate_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.gate_proj.weight', exp_id, - 'gate') - up_param = ('.experts.gate_up_weights', - f'.experts.{exp_id}.up_proj.weight', exp_id, 'up') - down_param = ('.experts.down_weights', - f'.experts.{exp_id}.down_proj.weight', exp_id, - 'down') + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', + exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', + exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', + exp_id, 'down') expert_params_mapping += [gate_param, up_param, down_param] params_dict = dict(self.named_parameters()) diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index b10baaa4d5..bfd6e352f1 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -1,18 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. 
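# [editor's note - illustrative sketch, not part of this patch]
# The qwen2.py change above reads 'num_replicate_key_value_heads' so that
# tensor parallelism can exceed the number of KV heads: each KV head is
# duplicated, and the QKV weight loaders added later in nn/linear.py chunk
# the KV weight by world_size // num_replicate_kv_heads and index it with
# rank // num_replicate_kv_heads. The helper below (hypothetical name)
# mirrors that indexing:
def kv_chunk(rank: int, world_size: int, num_replicate_kv_heads: int):
    """How many chunks the original KV weight splits into, and which one this rank loads."""
    num_chunks = world_size // num_replicate_kv_heads
    chunk_idx = rank // num_replicate_kv_heads
    return num_chunks, chunk_idx

# e.g. 8-way TP over 4 KV heads (num_replicate_kv_heads = 2): the weight is
# split into 4 chunks; ranks 0-1 load chunk 0, ranks 2-3 load chunk 1, etc.
# [end editor's note]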
-from typing import Any, Callable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, - SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention, + LayerNorm, RMSNorm, RopeType, SiluAndMul, + build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2 +from .utils.model import DeployModelMixin def _apply_mrope_selection(hidden_states: torch.Tensor, @@ -254,7 +260,6 @@ def __init__(self, self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.mrope_section = config.rope_scaling['mrope_section'] - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -271,7 +276,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) @@ -337,7 +341,337 @@ def get_input_embeddings(self): return self.embed_tokens -class Qwen2VLForConditionalGeneration(nn.Module, CudaGraphMixin): +class PatchEmbed(nn.Module): + """Patch Embed.""" + + def __init__(self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view(-1, self.in_channels, + self.temporal_patch_size, + self.patch_size, self.patch_size) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim) + return hidden_states + + +class VisionRotaryEmbedding(nn.Module): + """vision rotary embedding.""" + + def __init__(self, + dim: int, + theta: float = 10000.0, + device: torch.device = None) -> None: + super().__init__() + inv_freq = 1.0 / (theta**( + torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class VisionAttention(nn.Module): + """Vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + 
super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + dim = config.embed_dim + num_heads = config.num_heads + head_dim = dim // num_heads + self.head_dim = head_dim + + # packed qkv + self.qkv = build_qkv_proj( + dim, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attention = FlashAttention( + num_heads, + head_dim, + causal=False, + ) + + # o_proj + self.proj = build_rowwise_linear(dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor] + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + # qkv proj + qkv_states = self.qkv(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + q, k, v = self.qkv.split_qkv(qkv_states) + + cos, sin = rotary_pos_emb + q, k = self.apply_rotary_pos_emb(q, k, cos, sin) + + attn_output = self.attention( + q, + k, + v, + q_start_loc=cu_seqlens[:-1], + q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1], + ) + + attn_output = attn_output.reshape(seq_length, -1) + + # o proj + attn_output = self.proj(attn_output) + return attn_output + + +class VisionMlp(nn.Module): + """Vision mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + dim = config.embed_dim + hidden_dim = int(config.embed_dim * config.mlp_ratio) + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.fc1 = build_colwise_linear( + dim, + hidden_dim, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.act = nn.GELU() + else: + self.act = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(hidden_dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + return self.fc2(self.act(self.fc1(x))) + + +class Qwen2VLVisionBlock(nn.Module): + """Vision block.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + self.norm1 = LayerNorm(config.embed_dim, + eps=1e-6, + dtype=dtype, + device=device) + self.norm2 = LayerNorm(config.embed_dim, + eps=1e-6, + dtype=dtype, + device=device) + + self.attn = VisionAttention(config, dtype=dtype, device=device) + + self.mlp = VisionMlp(config, dtype=dtype, device=device) + + def forward(self, + hidden_states, + cu_seqlens, + rotary_pos_emb, + residual: Optional[torch.Tensor] = None) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + else: + hidden_states, residual = self.norm1(hidden_states, residual) + + hidden_states = self.attn(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb) + + hidden_states, residual = self.norm2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class PatchMerger(nn.Module): + """PatchMerger.""" + + def 
__init__(self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = nn.LayerNorm(context_dim, + eps=1e-6, + dtype=dtype, + device=device) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, + self.hidden_size, + dtype=dtype, + device=device), + nn.GELU(), + nn.Linear(self.hidden_size, dim, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen2VisionTransformerPretrainedModel(nn.Module): + """Vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + dtype=dtype, + device=device, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2, + device=device) + + self.blocks = nn.ModuleList([ + Qwen2VLVisionBlock(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.depth) + ]) + self.merger = PatchMerger(dim=config.hidden_size, + context_dim=config.embed_dim, + spatial_merge_size=config.spatial_merge_size, + dtype=dtype, + device=device) + + def rot_pos_emb(self, grid_thw): + """rotary position embedding.""" + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.patch_embed(hidden_states) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + residual = None + for blk in self.blocks: + hidden_states, residual = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + residual=residual) + + hidden_states = hidden_states + residual + + return self.merger(hidden_states) + + +class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin, + CudaGraphMixin): """ModelForCausalLM.""" packed_modules_mapping = { @@ -360,6 +694,16 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + + # preprocessor + self.input_processor = Qwen2VLInputProcessor(self.config) + + # build vision model + self.visual = Qwen2VisionTransformerPretrainedModel( + config.vision_config, + dtype=dtype, + device=device, + ) # build model 
self.model = Qwen2Model(config, dtype=dtype, device=device) # build lm_head @@ -377,9 +721,26 @@ def forward( attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, mrope_position_ids: torch.Tensor = None, + pixel_values: torch.Tensor = None, + vis_cu_seqlens: torch.Tensor = None, + vis_pos_emb: torch.Tensor = None, + image_mask: torch.Tensor = None, **kwargs, ): """model forward, return logits.""" + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if pixel_values is not None: + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), + vis_pos_emb[1].to(dtype)) + image_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb) + inputs_embeds = inputs_embeds.masked_scatter( + image_mask[..., None], image_embeds) + hidden_states = self.model( input_ids=input_ids, position_ids=position_ids, @@ -416,6 +777,36 @@ def prepare_inputs_for_generation( position_ids = context.position_ids attn_metadata = context.attn_metadata + pixel_values = None + vis_cu_seqlens = None + vis_pos_emb = None + image_mask = None + if context.input_multimodals is not None: + image_data = [ + input_mm['image'] for input_mm in context.input_multimodals + ] + + if len(image_data) > 0: + # flatten batch + image_data = [ + data for im_data in image_data for data in im_data + ] + pixel_values = torch.cat([data.data for data in image_data]) + image_token_id = image_data[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + grid_thw = torch.cat( + [data.meta['grid_thw'] for data in image_data]).cpu() + vis_pos_emb = self.visual.rot_pos_emb(grid_thw) + vis_cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).to(pixel_values.device) + vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, + dtype=torch.int32) + vis_pos_emb = vis_pos_emb.repeat(1, 2) + vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin()) + + mrope_position_ids = getattr(context, 'mrope_position_ids', None) + # process vision embeddings vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -433,7 +824,11 @@ def prepare_inputs_for_generation( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - mrope_position_ids=context.mrope_position_ids, + mrope_position_ids=mrope_position_ids, + pixel_values=pixel_values, + vis_cu_seqlens=vis_cu_seqlens, + vis_pos_emb=vis_pos_emb, + image_mask=image_mask, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -450,8 +845,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if 'visual' in name: - continue if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name @@ -467,8 +860,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): load_weight(param, loaded_weight, shard_id=shard_id) break else: - param = params_dict[name] - load_weight(param, loaded_weight) + if '.qkv.' 
in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): """make cudagraph buffers from forward inputs.""" @@ -510,3 +910,130 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): 'mrope_position_ids'] return new_inputs + + def _update_model_meta_decoding(self, context: StepContext): + """update model meta for decoding.""" + model_metas = context.model_metas + position_ids = context.position_ids + + mrope_deltas = [meta['mrope_delta'] for meta in model_metas] + mrope_deltas = position_ids.new_tensor(mrope_deltas) + mrope_position_ids = position_ids + mrope_deltas[None] + mrope_position_ids = mrope_position_ids.expand(3, -1) + + context.mrope_position_ids = mrope_position_ids + return model_metas + + def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): + """get mrope ids.""" + t, h, w = grid_thw + h //= 2 + w //= 2 + stride = torch.tensor([h * w, w, 1], device=device)[:, None] + size = torch.tensor([t, h, w], device=device)[:, None] + pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) + pos_ids = pos_ids // stride % size + return pos_ids + + def _update_model_meta_prefilling(self, context: StepContext): + """update model meta for prefilling.""" + model_metas = context.model_metas + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_multimodals = [None] * len(model_metas) + position_ids = context.position_ids + batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) + mrope_position_ids = [] + new_model_metas = [] + for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, + input_multimodals): + images = [] + if input_mm is not None: + images = input_mm['image'] + if model_meta is None or 'mrope_delta' not in model_meta: + mrope_delta = 0 + else: + mrope_delta = model_meta['mrope_delta'] + + pos_start = pos_ids[0].item() + mrope_pos_ids = pos_ids + mrope_delta + mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() + for img in images: + grid_thw = img.meta['grid_thw'][0].tolist() + _, h, w = grid_thw + h //= 2 + w //= 2 + num_pad = img.end - img.start - max(h, w) + mrope_delta -= num_pad + fill_start = img.start - pos_start + fill_end = img.end - pos_start + img_pos_ids = self._get_multimodal_pos_ids( + grid_thw, pos_ids.device) + img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] + mrope_pos_ids[:, fill_end:] -= num_pad + mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids + + mrope_position_ids.append(mrope_pos_ids) + new_model_metas.append(dict(mrope_delta=mrope_delta)) + + mrope_position_ids = torch.cat(mrope_position_ids, dim=1) + context.mrope_position_ids = mrope_position_ids + + return new_model_metas + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + if context.is_decoding: + return self._update_model_meta_decoding(context) + else: + return self._update_model_meta_prefilling(context) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +InputMultiModalType = List[Dict[str, Any]] + + +class Qwen2VLInputProcessor(BaseModelInputProcessor): + """qwen2 input processor.""" + 
+ def __init__(self, config: PretrainedConfig) -> None: + self.config = config + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + image_grid_thw = input_mm['image_grid_thw'] + offset = input_mm['offset'] + start = offset + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=start, + end=start + num_pad, + meta=dict( + grid_thw=image_grid_thw, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py new file mode 100644 index 0000000000..99bd4c4bfb --- /dev/null +++ b/lmdeploy/pytorch/models/utils/model.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterable, List, Optional, Tuple + +import torch + +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor +from lmdeploy.pytorch.model_inputs import StepContext + + +class DeployModelMixin: + + def forward(self, *args, **kwargs): + """forward of model.""" + raise NotImplementedError('Not Implemented') + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + raise NotImplementedError('Not Implemented') + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + raise NotImplementedError('Not Implemented') + + def get_logits(self, hidden_states: torch.Tensor): + """compute logits of the model output.""" + return hidden_states + + def update_weights(self): + """update weights.""" + pass + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + return None + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return None diff --git a/lmdeploy/pytorch/models/utils/multimodal.py b/lmdeploy/pytorch/models/utils/multimodal.py new file mode 100644 index 0000000000..aebcaf4073 --- /dev/null +++ b/lmdeploy/pytorch/models/utils/multimodal.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs + +PreparedInputs = Tuple[List[int], MultiModalInputs] + + +class MultiModalMixin: + + def prepare_multimodal_input(self, input_ids, input_multimodals, + **kwargs) -> PreparedInputs: + """prepare multimodals inputs.""" + raise NotImplementedError('prepare input not implemented.') diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py new file mode 100644 index 0000000000..c3e8c6a16f --- /dev/null +++ b/lmdeploy/pytorch/multimodal/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
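# [editor's note - illustrative sketch, not part of this patch]
# _get_multimodal_pos_ids in qwen2_vl.py above builds the 3-row M-RoPE
# position ids (temporal, height, width) for one image from its grid_thw
# after the 2x2 spatial merge. A self-contained sketch of the same indexing
# (assuming h and w are divisible by 2, as they are after patching):
import torch

def mrope_image_pos_ids(t: int, h: int, w: int) -> torch.Tensor:
    """Return a (3, t * h//2 * w//2) tensor of temporal/height/width indices."""
    h, w = h // 2, w // 2
    stride = torch.tensor([h * w, w, 1])[:, None]
    size = torch.tensor([t, h, w])[:, None]
    idx = torch.arange(t * h * w)[None].expand(3, -1)
    return idx // stride % size
# [end editor's note]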
+from .data_type import MultiModalData, MultiModalTensor + +__all__ = ['MultiModalData', 'MultiModalTensor'] diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py new file mode 100644 index 0000000000..886c7ffbd0 --- /dev/null +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass, fields +from typing import Any, Dict, List, Union + +import torch +from torch import Tensor +from torch import distributed as dist + + +class MultiModalData: + pass + + +MultiModalDataList = List[MultiModalData] + +NestedTensor = Union[Tensor, List[Tensor]] + + +def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'): + """broadcast tensor.""" + if value.device.type == 'meta': + value = torch.empty_like(value, device=device) + dist.broadcast(value, src) + return value + + +@dataclass +class MultiModalTensor: + data: NestedTensor + start: int + end: int = None + encoder_len: int = None + meta: Dict[str, Any] = None + + def __post_init__(self): + if self.end is None: + self.end = self.start + + def to_device(self, device: str, non_blocking: bool = False): + """to device.""" + out_dict = dict() + for f in fields(self): + k = f.name + if k in ('data', 'meta'): + continue + v = getattr(self, k) + out_dict[k] = v + + if isinstance(self.data, Tensor): + data = self.data.to(device=device, non_blocking=non_blocking) + else: + data = [ + d.to(device=device, non_blocking=non_blocking) + for d in self.data + ] + out_dict['data'] = data + + new_meta = None + if self.meta is not None: + new_meta = dict() + for k, v in self.meta.items(): + if isinstance(v, Tensor): + v = v.to(device=device, non_blocking=non_blocking) + elif hasattr(v, 'to_device'): + v = v.to_device(device=device, non_blocking=non_blocking) + new_meta[k] = v + + out_dict['meta'] = new_meta + return MultiModalTensor(**out_dict) + + def broadcast(self): + """broadcast inputs tensors.""" + out_dict = dict() + for f in fields(self): + k = f.name + if k in ('data', 'meta'): + continue + v = getattr(self, k) + out_dict[k] = v + + if isinstance(self.data, Tensor): + data = _broadcast_tensor(self.data) + else: + data = [_broadcast_tensor(d) for d in self.data] + out_dict['data'] = data + + new_meta = None + if self.meta is not None: + new_meta = dict() + for k, v in self.meta.items(): + if isinstance(v, Tensor): + v = _broadcast_tensor(v) + self.meta[k] = v + elif hasattr(v, 'to_device'): + assert hasattr(v, 'broadcast') + v = v.broadcast() + self.meta[k] = v + new_meta[k] = v + + out_dict['meta'] = new_meta + return MultiModalTensor(**out_dict) + + +MultiModalInputs = Dict[str, List[MultiModalTensor]] diff --git a/lmdeploy/pytorch/multimodal/image_type.py b/lmdeploy/pytorch/multimodal/image_type.py new file mode 100644 index 0000000000..19211a381f --- /dev/null +++ b/lmdeploy/pytorch/multimodal/image_type.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
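# [editor's note - illustrative usage sketch, not part of this patch]
# data_type.py above defines MultiModalTensor, which ties a vision tensor to
# the span of image tokens it replaces (start/end in input_ids) plus free-form
# metadata. Typical construction and device transfer (the offset, token count
# and image_token_id below are hypothetical values):
import torch
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor

pixel_values = torch.zeros(4, 3)                        # dummy vision features
offset, num_image_tokens = 5, 4
img = MultiModalTensor(data=pixel_values,
                       start=offset,
                       end=offset + num_image_tokens,
                       meta=dict(image_token_id=0))
img = img.to_device('cuda', non_blocking=True)          # requires a CUDA device
mm_inputs = dict(image=[img])                           # MultiModalInputs: modality -> list of tensors
# [end editor's note]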
+from dataclasses import dataclass +from typing import Any, ClassVar, Dict + +from PIL import Image + +from .data_type import MultiModalData + + +@dataclass +class ImageData(MultiModalData): + data: Image + loc: int + meta: Dict[str, Any] = None + type: ClassVar[str] = 'image' diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index 63df9a5ae9..4705115bf4 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -2,7 +2,7 @@ # attention module is modified from: # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ from .activation import GeluAndMul, SiluAndMul # noqa: F401 -from .attention import Attention # noqa: F401 +from .attention import Attention, FlashAttention # noqa: F401 from .norm import LayerNorm, RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 from .rotary_embedding import RopeType # noqa: F401 diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index 26f1034d36..684c8122f8 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -9,6 +9,14 @@ from .utils import get_distribute_size +def _update_num_heads(num_heads: int, num_kv_heads: int): + """update heads.""" + world_size, rank = get_world_rank() + num_heads = get_distribute_size(num_heads, world_size, rank) + num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) + return num_heads, num_kv_heads + + class Attention(nn.Module): """Attention layer.""" @@ -22,15 +30,19 @@ def __init__( alibi: bool = False, sliding_window: int = None, logit_softcapping: float = None, - replicate_kv: bool = False, + causal: bool = True, **kwargs, ): super().__init__() - num_heads, num_kv_heads = self._update_num_heads( - num_heads, num_kv_heads, replicate_kv) + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_size is None: + v_head_size = head_size + num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads) layer_backend = get_backend() - impl_builder = layer_backend.get_layer_impl_builder(OpType.Attention) + impl_builder = layer_backend.get_layer_impl_builder( + OpType.PagedAttention) self.impl = impl_builder.build( num_heads=num_heads, @@ -41,18 +53,10 @@ def __init__( alibi=alibi, sliding_window=sliding_window, logit_softcapping=logit_softcapping, + causal=causal, **kwargs, ) - def _update_num_heads(self, num_heads: int, num_kv_heads: int, - replicate_kv: bool): - """update heads.""" - world_size, rank = get_world_rank() - num_heads = get_distribute_size(num_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) - return num_heads, num_kv_heads - def forward( self, query: torch.Tensor, @@ -77,3 +81,75 @@ def forward( v_scales_zeros=v_scales_zeros, inplace=inplace, ) + + +class FlashAttention(nn.Module): + """flash attention w/o paging.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logit_softcapping: float = None, + **kwargs, + ): + super().__init__() + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_dim is None: + v_head_dim = head_dim + num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads) + + layer_backend = get_backend() + + impl_builder = layer_backend.get_layer_impl_builder( + OpType.FlashAttention) + + self.impl = impl_builder.build( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + 
num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logit_softcapping=logit_softcapping, + **kwargs, + ) + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + q_start_loc: torch.Tensor, + q_seqlens: torch.Tensor, + kv_start_loc: torch.Tensor = None, + kv_seqlens: torch.Tensor = None, + max_q_seqlen: int = None) -> torch.Tensor: + """forward.""" + + if max_q_seqlen is None: + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + + if kv_start_loc is None and kv_seqlens is None: + kv_start_loc = q_start_loc + kv_seqlens = q_seqlens + + assert kv_start_loc is not None + assert kv_seqlens is not None + + return self.impl.forward( + query, + key, + value, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_q_seqlen=max_q_seqlen, + ) diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 08040ee00c..73d0ef918d 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -12,7 +12,7 @@ from ..backends import OpType, get_backend from ..backends.lora import AdapterInfo -from .utils import div_up, get_distribute_size +from .utils import chunk_aligned, div_up, get_distribute_size logger = get_logger('lmdeploy') @@ -25,37 +25,30 @@ def _check_qkv_split_layout(layout: str): f'but get: {layout}') -def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): - """chunk aligned.""" - if align == 1: - return weight.chunk(chunks, dim=dim) - size = weight.size(dim) - assert size % align == 0 - aligned_size = size // align - align_per_chunk = div_up(aligned_size, chunks) - sections = [align_per_chunk] * (chunks - 1) - sections += [aligned_size - align_per_chunk * (chunks - 1)] - sections = [sec * align for sec in sections] - return weight.split(sections, dim=dim) +_chunk_align = chunk_aligned class QKVMixin: """qkv mixin.""" - def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, - head_size: int, head_size_v: int): + def _get_qkv_out_features(self, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + head_size_v: int, + num_replicate_kv_heads: int = 1): """get io features.""" - all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, - num_kv_heads * head_size_v) + num_kv_heads_real = num_kv_heads // num_replicate_kv_heads + all_out_features = (num_q_heads * head_size, + num_kv_heads_real * head_size, + num_kv_heads_real * head_size_v) return all_out_features - def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, - replicate_kv: bool): + def _update_num_heads(self, num_q_heads: int, num_kv_heads: int): """update num heads.""" world_size, rank = get_world_rank() num_q_heads = get_distribute_size(num_q_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) + num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) return num_q_heads, num_kv_heads @@ -159,6 +152,239 @@ def weight_loader_B(self, param: nn.Parameter, loaded_weight: torch.Tensor, param_r.copy_(loaded_weight.t()) +class BlockedF8Linear(nn.Module): + """blocked f8 linear.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + colwise: bool = True, + is_tp: bool = False, + all_reduce: bool = True, + ): + super().__init__() + if device is None: + device = 
torch.device('cpu') + if dtype is None: + dtype = torch.float16 + if is_tp: + in_features, out_features = self._get_io_features( + in_features, out_features, colwise) + impl_builder = get_backend().get_layer_impl_builder( + OpType.LinearBlockedF8) + self.impl = impl_builder.build(in_features, + out_features, + block_size=128, + bias=bias is not None, + dtype=dtype) + self.block_size = 128 + self.fp8_dtype = fp8_dtype + weight, scale, bias = self.create_weights(in_features, out_features, + bias, dtype, device) + weight = torch.nn.Parameter(weight, requires_grad=False) + weight.weight_loader = self.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + scale.weight_loader = self.weight_loader + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + bias.weight_loader = self.weight_loader + self.register_parameter('weight', weight) + self.register_parameter('scale', scale) + self.register_parameter('bias', bias) + + self.in_features = in_features + self.out_features = out_features + self.lora_adapters = nn.ModuleDict() + self.is_tp = is_tp + self.colwise = colwise + self.all_reduce = all_reduce + + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + world_size, rank = get_world_rank() + if colwise: + out_features = get_distribute_size(out_features, world_size, rank) + else: + in_features = get_distribute_size(in_features, world_size, rank) + return in_features, out_features + + def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for colwise linear.""" + weight = loaded_weight.chunk(world_size, 0)[rank] + return default_weight_loader(param, weight) + + def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for rowwise linear.""" + if loaded_weight.dim() == 2: + weight = loaded_weight.chunk(world_size, 1)[rank] + return default_weight_loader(param, weight) + else: + # bias + if rank != 0: + loaded_weight = torch.zeros_like(loaded_weight) + return default_weight_loader(param, loaded_weight) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + """weight loader.""" + if not self.is_tp: + return default_weight_loader(param, loaded_weight) + + world_size, rank = get_world_rank() + if self.colwise: + return self._weight_loader_tp_colwise(param, loaded_weight, rank, + world_size) + else: + return self._weight_loader_tp_rowwise(param, loaded_weight, rank, + world_size) + + def create_weights(self, in_features: int, out_features: int, bias: bool, + dtype: torch.dtype, device: torch.device): + """create weights.""" + weight = torch.empty((out_features, in_features), + dtype=self.fp8_dtype, + device=device) + scale = torch.empty( + (div_up(out_features, + self.block_size), div_up(in_features, self.block_size)), + dtype=torch.float32, + device=device) + if bias: + bias = torch.empty((out_features, ), dtype=dtype, device=device) + else: + bias = None + return weight, scale, bias + + def update_weights(self): + """update weights.""" + weight, scale, bias = self.impl.update_weights(self.weight, self.scale, + self.bias) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight.weight_loader = self.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + self.scale.weight_loader = self.weight_loader + if bias is not None: + bias = torch.nn.Parameter(bias, 
requires_grad=False) + self.bias.weight_loader = self.weight_loader + self.register_parameter('weight', weight) + self.register_parameter('scale', scale) + self.register_parameter('bias', bias) + + def forward(self, x): + """forward of blocked fp8 linear.""" + all_reduce = False if self.colwise else self.is_tp + all_reduce = all_reduce and self.all_reduce + if len(self.lora_adapters) == 0: + return self.impl.forward(x, self.weight, self.scale, self.bias, + all_reduce) + + out = self.impl.forward(x, self.weight, self.scale, self.bias, False) + for lora_adapter in self.lora_adapters.values(): + out = lora_adapter(x, out) + if all_reduce: + dist.all_reduce(out) + return out + + +class MergedBlockedF8Linear(BlockedF8Linear): + """merged blocked fp8 linear.""" + + def __init__(self, + in_features: int, + all_out_features: List[int], + bias: bool, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + replicate: Optional[List[bool]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = True, + out_names: Optional[List[int]] = None): + if replicate is None: + replicate = tuple(False for _ in all_out_features) + self.block_size = 128 + self.split_section = all_out_features + self.scale_split_section = [ + section // self.block_size for section in self.split_section + ] + all_out_features = self._update_all_out_features( + all_out_features, replicate) + self.all_out_features = all_out_features + self.replicate = replicate + if out_names is None: + out_names = torch.arange(len(self.all_out_features)).tolist() + assert len(out_names) == len(self.all_out_features) + self.out_names_map = dict( + (name, idx) for idx, name in enumerate(out_names)) + out_features = sum(all_out_features) + super().__init__(in_features, + out_features, + bias, + dtype, + device, + fp8_dtype=fp8_dtype, + colwise=True, + is_tp=is_tp) + self.weight.weight_loader = self.weight_loader + self.scale.weight_loader = self.weight_loader + self.weight.weight_spliter = self.weight_spliter + self.scale.weight_spliter = self.weight_spliter + if self.bias is not None: + self.bias.weight_loader = self.weight_loader + self.bias.weight_spliter = self.weight_spliter + + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + return in_features, out_features + + def _update_all_out_features(self, all_out_features: List[int], + replicate: Optional[List[bool]]): + """update all out features.""" + world_size, rank = get_world_rank() + new_all_out_features = [] + for out_feat, rep in zip(all_out_features, replicate): + if rep: + new_all_out_features.append(out_feat) + new_out_feat = get_distribute_size(out_feat, world_size, rank) + new_all_out_features.append(new_out_feat) + return new_all_out_features + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = get_world_rank() + shard_idx = self.out_names_map[shard_id] + if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32: + all_out_features = [ + feats // self.block_size for feats in self.all_out_features + ] + param_w = param.data.split(all_out_features, 0)[shard_idx] + else: + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + if not self.replicate[shard_idx]: + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + param_w.copy_(loaded_weight) + + def weight_spliter(self, loaded_weight: torch.Tensor): + """weight spliter.""" + if loaded_weight.dim() == 2 and 
loaded_weight.dtype == torch.float32: + return loaded_weight.split(self.scale_split_section, dim=0) + return loaded_weight.split(self.split_section, dim=0) + + def weight_spliter_lora_b(self, loaded_weight: torch.Tensor): + return loaded_weight.split(self.split_section, dim=0) + + class AwqLinear(nn.Module): """w4a16 linear.""" @@ -212,7 +438,7 @@ def __init__( self.out_features = out_features self.w_bit = w_bit self.group_size = group_size - self.elem_per_int = 32 // self.w_bit + self.elem_per_int = 32 // w_bit self.lora_adapters = nn.ModuleDict() self.is_tp = is_tp self.colwise = colwise @@ -363,12 +589,9 @@ def __init__(self, w_bit: int, group_size: int, bias: bool, - replicate: Optional[List[bool]] = None, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None): - if replicate is None: - replicate = tuple(False for _ in all_out_features) self.split_section_s = all_out_features elem_per_int = 32 // w_bit @@ -377,9 +600,8 @@ def __init__(self, ] all_out_features = self._update_all_out_features( - all_out_features, w_bit, group_size, replicate) + all_out_features, w_bit, group_size) self.all_out_features = all_out_features - self.replicate = replicate if out_names is None: out_names = torch.arange(len(self.all_out_features)).tolist() assert len(out_names) == len(self.all_out_features) @@ -414,15 +636,12 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int, return in_features, out_features def _update_all_out_features(self, all_out_features: List[int], w_bit: int, - group_size: int, - replicate: Optional[List[bool]]): + group_size: int): """update all out features.""" world_size, rank = get_world_rank() new_all_out_features = [] align = max(32 // w_bit, group_size) - for out_feat, rep in zip(all_out_features, replicate): - if rep: - new_all_out_features.append(out_feat) + for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank, align) new_all_out_features.append(new_out_feat) @@ -433,14 +652,11 @@ def weight_loader(self, param: torch.nn.Parameter, """weight loader.""" world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] - if loaded_weight.dim() == 1: # bias align = max(self.elem_per_int, self.group_size) param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - weight = _chunk_align(loaded_weight, world_size, 0, - align)[rank] + weight = _chunk_align(loaded_weight, world_size, 0, align)[rank] param_w.copy_(weight) if param._weight_type in ['scales', 'bias']: @@ -456,8 +672,7 @@ def weight_loader(self, param: torch.nn.Parameter, ] param_w = param.data.split(quanted_out_feats, 1)[shard_idx] - if not self.replicate[shard_idx]: - weight = _chunk_align(loaded_weight, world_size, -1, align)[rank] + weight = _chunk_align(loaded_weight, world_size, -1, align)[rank] param_w.copy_(weight) def weight_spliter_wz(self, loaded_weight: torch.Tensor): @@ -480,45 +695,82 @@ def __init__(self, head_size_v: int, w_bit: int, group_size: int, - replicate_kv: bool = False, bias: bool = False, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1): self.qkv_split_section_s = self._get_qkv_out_features( - num_q_heads, num_kv_heads, head_size, head_size_v) + num_q_heads, num_kv_heads, head_size, head_size_v, + num_replicate_kv_heads) elem_per_int = 32 // w_bit self.qkv_split_section_wz = [ size // elem_per_int for size in self.qkv_split_section_s ] num_q_heads, num_kv_heads = 
self._update_num_heads( - num_q_heads, num_kv_heads, replicate_kv) + num_q_heads, num_kv_heads) all_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v) - replicate = (False, replicate_kv, replicate_kv) out_names = ('q', 'k', 'v') self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size self.head_size_v = head_size_v + self.num_replicate_kv_heads = num_replicate_kv_heads + super().__init__(in_features, all_out_features, w_bit=w_bit, group_size=group_size, bias=bias, - replicate=replicate, device=device, is_tp=is_tp, out_names=out_names) def _update_all_out_features(self, all_out_features: List[int], w_bit: int, - group_size: int, - replicate: Optional[List[bool]]): + group_size: int): """update all out features.""" return all_out_features + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = get_world_rank() + chunk_size, chunk_idx = world_size, rank + shard_idx = self.out_names_map[shard_id] + + if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']: + # update to duplicate k/v for tp_size > num_kv_heads + chunk_size = world_size // self.num_replicate_kv_heads + chunk_idx = rank // self.num_replicate_kv_heads + + if loaded_weight.dim() == 1: + # bias + align = max(self.elem_per_int, self.group_size) + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + weight = _chunk_align(loaded_weight, chunk_size, 0, + align)[chunk_idx] + param_w.copy_(weight) + return + + if param._weight_type in ['scales', 'bias']: + # scales + align = max(self.elem_per_int, self.group_size) + param_w = param.data.split(self.all_out_features, -1)[shard_idx] + else: + # qweight or qzeros + align = max(self.elem_per_int, + self.group_size) // self.elem_per_int + quanted_out_feats = [ + feat // self.elem_per_int for feat in self.all_out_features + ] + param_w = param.data.split(quanted_out_feats, 1)[shard_idx] + + weight = _chunk_align(loaded_weight, chunk_size, -1, align)[chunk_idx] + param_w.copy_(weight) + def weight_spliter_wz(self, loaded_weight: torch.Tensor, layout: str = 'default'): @@ -566,17 +818,16 @@ def weight_spliter_lora_b(self, loaded_weight: torch.Tensor): class W8A8Linear(nn.Module): """w8a8 linear.""" - def __init__( - self, - in_features: int, - out_features: int, - bias: bool, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - colwise: bool = True, - is_tp: bool = False, - all_reduce: bool = True, - ): + def __init__(self, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + colwise: bool = True, + is_tp: bool = False, + all_reduce: bool = True, + quant_dtype: Optional[torch.dtype] = torch.int8): super().__init__() if device is None: device = torch.device('cpu') @@ -586,10 +837,12 @@ def __init__( in_features, out_features = self._get_io_features( in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8) + self.quant_dtype = quant_dtype self.impl = impl_builder.build(in_features, out_features, bias is not None, - dtype=dtype) + dtype=dtype, + quant_dtype=quant_dtype) weight, scale, bias = self.create_weights(in_features, out_features, bias, dtype, device) weight = torch.nn.Parameter(weight, requires_grad=False) @@ -631,7 +884,9 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int, world_size: int): 
"""weight loader for rowwise linear.""" - if loaded_weight.dim() == 2 and param.dtype == torch.int8: + if loaded_weight.dim() == 2 and param.dtype in (torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2): weight = loaded_weight.chunk(world_size, 1)[rank] return default_weight_loader(param, weight) elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1: @@ -661,7 +916,7 @@ def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device): """create weights.""" weight = torch.empty((out_features, in_features), - dtype=torch.int8, + dtype=self.quant_dtype, device=device) scale = torch.empty((out_features, 1), dtype=torch.float32, @@ -710,18 +965,14 @@ def __init__(self, in_features: int, all_out_features: List[int], bias: bool, - replicate: Optional[List[bool]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, - out_names: Optional[List[int]] = None): - if replicate is None: - replicate = tuple(False for _ in all_out_features) + out_names: Optional[List[int]] = None, + quant_dtype: torch.dtype = torch.int8): self.split_section = all_out_features - all_out_features = self._update_all_out_features( - all_out_features, replicate) + all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features - self.replicate = replicate if out_names is None: out_names = torch.arange(len(self.all_out_features)).tolist() assert len(out_names) == len(self.all_out_features) @@ -734,7 +985,8 @@ def __init__(self, dtype, device, colwise=True, - is_tp=is_tp) + is_tp=is_tp, + quant_dtype=quant_dtype) self.weight.weight_loader = self.weight_loader self.scale.weight_loader = self.weight_loader self.weight.weight_spliter = self.weight_spliter @@ -748,14 +1000,11 @@ def _get_io_features(self, in_features: int, out_features: int, """get io features.""" return in_features, out_features - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" world_size, rank = get_world_rank() new_all_out_features = [] - for out_feat, rep in zip(all_out_features, replicate): - if rep: - new_all_out_features.append(out_feat) + for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank) new_all_out_features.append(new_out_feat) return new_all_out_features @@ -766,8 +1015,7 @@ def weight_loader(self, param: torch.nn.Parameter, world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] param_w.copy_(loaded_weight) def weight_spliter(self, loaded_weight: torch.Tensor): @@ -787,38 +1035,60 @@ def __init__(self, num_kv_heads: int, head_size: int, head_size_v: int, - replicate_kv: bool = False, bias: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1, + quant_dtype: torch.dtype = torch.int8): + self.qkv_split_section = self._get_qkv_out_features( - num_q_heads, num_kv_heads, head_size, head_size_v) + num_q_heads, num_kv_heads, head_size, head_size_v, + num_replicate_kv_heads) num_q_heads, num_kv_heads = self._update_num_heads( - num_q_heads, num_kv_heads, replicate_kv) 
+ num_q_heads, num_kv_heads) all_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v) - replicate = (False, replicate_kv, replicate_kv) out_names = ('q', 'k', 'v') self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size self.head_size_v = head_size_v + self.num_replicate_kv_heads = num_replicate_kv_heads super().__init__(in_features, all_out_features, bias=bias, - replicate=replicate, dtype=dtype, device=device, is_tp=is_tp, - out_names=out_names) + out_names=out_names, + quant_dtype=quant_dtype) - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" return all_out_features + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + _, rank = get_world_rank() + shard_idx = self.out_names_map[shard_id] + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + num_head = self.num_q_heads if shard_id == 'q' \ + else self.num_kv_heads + head_dim = self.head_size if shard_id in ['q', 'k'] \ + else self.head_size_v + # update to duplicate k/v for tp_size > num_kv_heads + rank_idx = rank if shard_id == 'q' \ + else rank // self.num_replicate_kv_heads + sec_start = rank_idx * num_head * head_dim + sec_len = num_head * head_dim + loaded_weight = loaded_weight.narrow(dim=0, + start=sec_start, + length=sec_len) + param_w.copy_(loaded_weight) + def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'): @@ -986,18 +1256,13 @@ def __init__(self, in_features: int, all_out_features: List[int], bias: bool, - replicate: Optional[List[bool]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None): - if replicate is None: - replicate = tuple(False for _ in all_out_features) self.split_section = all_out_features - all_out_features = self._update_all_out_features( - all_out_features, replicate) + all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features - self.replicate = replicate if out_names is None: out_names = torch.arange(len(self.all_out_features)).tolist() assert len(out_names) == len(self.all_out_features) @@ -1022,14 +1287,11 @@ def _get_io_features(self, in_features: int, out_features: int, """get io features.""" return in_features, out_features - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" world_size, rank = get_world_rank() new_all_out_features = [] - for out_feat, rep in zip(all_out_features, replicate): - if rep: - new_all_out_features.append(out_feat) + for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank) new_all_out_features.append(new_out_feat) return new_all_out_features @@ -1040,8 +1302,7 @@ def weight_loader(self, param: torch.nn.Parameter, world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] param_w.copy_(loaded_weight) def weight_spliter(self, loaded_weight: torch.Tensor): @@ -1061,35 +1322,36 @@ def 
__init__(self, num_kv_heads: int, head_size: int, head_size_v: int, - replicate_kv: bool = False, bias: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1): + self.qkv_split_section = self._get_qkv_out_features( - num_q_heads, num_kv_heads, head_size, head_size_v) + num_q_heads, num_kv_heads, head_size, head_size_v, + num_replicate_kv_heads) num_q_heads, num_kv_heads = self._update_num_heads( - num_q_heads, num_kv_heads, replicate_kv) + num_q_heads, num_kv_heads) all_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v) - replicate = (False, replicate_kv, replicate_kv) out_names = ('q', 'k', 'v') self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size self.head_size_v = head_size_v + self.num_replicate_kv_heads = num_replicate_kv_heads + super().__init__(in_features, all_out_features, bias=bias, - replicate=replicate, dtype=dtype, device=device, is_tp=is_tp, out_names=out_names) - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" return all_out_features @@ -1097,15 +1359,20 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """weight loader.""" world_size, rank = get_world_rank() + chunk_size, chunk_idx = world_size, rank shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - if shard_idx in [0, 1]: - loaded_weight = _chunk_align(loaded_weight, world_size, 0, - self.head_size)[rank] - if shard_idx == 2: - loaded_weight = _chunk_align(loaded_weight, world_size, 0, - self.head_size_v)[rank] + + if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']: + # update to duplicate k/v for tp_size > num_kv_heads + chunk_size = world_size // self.num_replicate_kv_heads + chunk_idx = rank // self.num_replicate_kv_heads + if shard_idx in [0, 1]: + loaded_weight = _chunk_align(loaded_weight, chunk_size, 0, + self.head_size)[chunk_idx] + elif shard_idx == 2: + loaded_weight = _chunk_align(loaded_weight, chunk_size, 0, + self.head_size_v)[chunk_idx] param_w.copy_(loaded_weight) def weight_spliter(self, @@ -1161,6 +1428,10 @@ def build_linear(in_features: int, ) quant_method = quant_config['quant_method'] + quant_dtype = torch.int8 + if 'quant_dtype' in quant_config: + quant_dtype = eval('torch.' 
+ quant_config['quant_dtype']) + if quant_method == 'awq': w_bit = quant_config.get('bits', 4) group_size = quant_config.get('group_size', 128) @@ -1176,10 +1447,28 @@ def build_linear(in_features: int, all_reduce=all_reduce, ) if quant_method == 'smooth_quant': - return W8A8Linear( + return W8A8Linear(in_features, + out_features, + bias=bias, + dtype=dtype, + device=device, + colwise=colwise, + is_tp=is_tp, + all_reduce=all_reduce, + quant_dtype=quant_dtype) + elif quant_method == 'fp8': + fmt = quant_config.get('fmt', 'e4m3') + if fmt == 'e4m3': + fp8_dtype = torch.float8_e4m3fn + elif fmt == 'e5m2': + fp8_dtype = torch.float8_e5m2 + else: + raise TypeError(f'Unsupported fp8 fmt: {fmt}') + return BlockedF8Linear( in_features, out_features, bias=bias, + fp8_dtype=fp8_dtype, dtype=dtype, device=device, colwise=colwise, @@ -1260,6 +1549,10 @@ def build_merged_colwise_linear( ) quant_method = quant_config['quant_method'] + quant_dtype = torch.int8 + if 'quant_dtype' in quant_config: + quant_dtype = eval('torch.' + quant_config['quant_dtype']) + if quant_method == 'awq': w_bit = quant_config.get('bits', 4) group_size = quant_config.get('group_size', 128) @@ -1273,10 +1566,27 @@ def build_merged_colwise_linear( is_tp=is_tp, ) if quant_method == 'smooth_quant': - return MergedW8A8Linear( + return MergedW8A8Linear(in_features=in_features, + all_out_features=all_out_features, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + out_names=out_names, + quant_dtype=quant_dtype) + elif quant_method == 'fp8': + fmt = quant_config.get('fmt', 'e4m3') + if fmt == 'e4m3': + fp8_dtype = torch.float8_e4m3fn + elif fmt == 'e5m2': + fp8_dtype = torch.float8_e5m2 + else: + raise TypeError(f'Unsupported fp8 fmt: {fmt}') + return MergedBlockedF8Linear( in_features=in_features, all_out_features=all_out_features, bias=bias, + fp8_dtype=fp8_dtype, dtype=dtype, device=device, is_tp=is_tp, @@ -1291,12 +1601,12 @@ def build_qkv_proj(in_features: int, num_kv_heads: int, head_size: int, head_size_v: int = None, - replicate_kv: bool = False, bias: bool = False, quant_config: Any = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1): """build qkv proj.""" if is_tp: world_size, _ = get_world_rank() @@ -1306,48 +1616,47 @@ def build_qkv_proj(in_features: int, head_size_v = head_size if quant_config is None: - return QKVBaseLinear( - in_features=in_features, - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - head_size_v=head_size_v, - replicate_kv=replicate_kv, - bias=bias, - dtype=dtype, - device=device, - is_tp=is_tp, - ) + return QKVBaseLinear(in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + num_replicate_kv_heads=num_replicate_kv_heads) quant_method = quant_config['quant_method'] + quant_dtype = torch.int8 + if 'quant_dtype' in quant_config: + quant_dtype = eval('torch.' 
+ quant_config['quant_dtype']) + if quant_method == 'awq': w_bit = quant_config.get('bits', 4) group_size = quant_config.get('group_size', 128) - return QKVAwqLinear( - in_features=in_features, - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - head_size_v=head_size_v, - replicate_kv=replicate_kv, - w_bit=w_bit, - group_size=group_size, - bias=bias, - device=device, - is_tp=is_tp, - ) + return QKVAwqLinear(in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + w_bit=w_bit, + group_size=group_size, + bias=bias, + device=device, + is_tp=is_tp, + num_replicate_kv_heads=num_replicate_kv_heads) if quant_method == 'smooth_quant': - return QKVW8A8Linear( - in_features=in_features, - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - head_size_v=head_size_v, - replicate_kv=replicate_kv, - bias=bias, - dtype=dtype, - device=device, - is_tp=is_tp, - ) + return QKVW8A8Linear(in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + num_replicate_kv_heads=num_replicate_kv_heads, + quant_dtype=quant_dtype) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 47176335c4..4921825c9a 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional +from typing import Any, List, Optional import torch import torch.distributed as dist @@ -8,6 +8,7 @@ from lmdeploy.pytorch.distributed import get_world_rank from ..backends import OpType, get_backend +from .utils import div_up class SoftmaxTopK(nn.Module): @@ -24,6 +25,102 @@ def forward(self, x: torch.Tensor): return self.impl.forward(x) +def create_mlp_weights(hidden_dim: int, ffn_dim: int, num_experts: int, + dtype: torch.dtype, device: torch.device): + """create weights.""" + gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim), + dtype=dtype, + device=device) + down_weights = torch.empty((num_experts, hidden_dim, ffn_dim), + dtype=dtype, + device=device) + return gate_up_weights, down_weights + + +def _update_args(hidden_dim: int, ffn_dim: int): + """update args.""" + world_size, _ = get_world_rank() + assert ffn_dim % world_size == 0 + ffn_dim = ffn_dim // world_size + return hidden_dim, ffn_dim + + +class LinearWeights(nn.Module): + """fused moe linear weights.""" + + def __init__(self, + num_experts: int, + in_features: int, + out_features: int, + weight_type: str, + dtype: torch.dtype, + device: torch.device, + expert_list: List[int] = None, + ep: bool = False): + super().__init__() + weight = torch.empty((num_experts, out_features, in_features), + dtype=dtype, + device=device) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.register_parameter('weight', weight) + self.ep = ep + self.expert_list = expert_list + self.weight_type = weight_type + self.half_out = out_features // 2 + + if self.ep: + self.expert_map = dict( + (eid, idx) for idx, eid in enumerate(expert_list)) + self.weight.weight_loader = self.weight_loader_ep + else: + self.weight.weight_loader = self.weight_loader_tp + + def update_weight(self, weight: torch.Tensor): + """update weight.""" + weight_loader = self.weight.weight_loader + weight = torch.nn.Parameter(weight, requires_grad=False) + 
weight.weight_loader = weight_loader + self.register_parameter('weight', weight) + + def weight_loader_tp(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """weight loader.""" + world_size, rank = get_world_rank() + if shard_id == 'gate': + param_data = param.data[expert_id, :self.half_out] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'up': + param_data = param.data[expert_id, self.half_out:] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'down': + param_data = param.data[expert_id] + weight = loaded_weight.chunk(world_size, dim=1)[rank] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(weight) + + def weight_loader_ep(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """weight loader.""" + expert_list = self.expert_list + if expert_id not in expert_list: + return + + expert_map = self.expert_map + param_id = expert_map[expert_id] + if shard_id == 'gate': + param_data = param.data[param_id, :self.half_out] + elif shard_id == 'up': + param_data = param.data[param_id, self.half_out:] + elif shard_id == 'down': + param_data = param.data[param_id] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(loaded_weight) + + class FusedMoE(nn.Module): """fused moe.""" @@ -46,42 +143,33 @@ def __init__(self, impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) self.impl = impl_builder.build(top_k, num_experts, renormalize) - self.expert_list = None - self.expert_map = None enable_ep = enable_ep and self.impl.support_ep() if enable_ep: world_size, rank = get_world_rank() expert_list = self.impl.ep_expert_list(world_size, rank) - self.expert_list = expert_list - self.expert_map = dict( - (eid, idx) for idx, eid in enumerate(expert_list)) num_experts = len(expert_list) - gate_up_weights, down_weights = self.create_weights(hidden_dim, - ffn_dim, - num_experts, - dtype=dtype, - device=device) - else: - hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim) - gate_up_weights, down_weights = self.create_weights(hidden_dim, - ffn_dim, - num_experts, - dtype=dtype, - device=device) - gate_up_weights = torch.nn.Parameter(gate_up_weights, - requires_grad=False) - down_weights = torch.nn.Parameter(down_weights, requires_grad=False) - gate_up_weights._weight_type = 'gate_up_weights' - down_weights._weight_type = 'down_weights' - self.register_parameter('gate_up_weights', gate_up_weights) - self.register_parameter('down_weights', down_weights) - - if enable_ep: - gate_up_weights.weight_loader = self.weight_loader_ep - down_weights.weight_loader = self.weight_loader_ep else: - gate_up_weights.weight_loader = self.weight_loader_tp - down_weights.weight_loader = self.weight_loader_tp + hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) + expert_list = None + self.expert_list = expert_list + self.gate_up = LinearWeights(num_experts, + hidden_dim, + ffn_dim * 2, + weight_type='gate_up', + dtype=dtype, + device=device, + expert_list=expert_list, + ep=enable_ep) + self.down = LinearWeights( + num_experts, + ffn_dim, + hidden_dim, + weight_type='down', + dtype=dtype, + device=device, + expert_list=expert_list, + ep=enable_ep, + ) self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim @@ -93,83 +181,375 @@ def __init__(self, all_reduce = False self.all_reduce = all_reduce - def _update_args(self, hidden_dim: int, ffn_dim: int): - """update args.""" - world_size, _ = get_world_rank() - 
assert ffn_dim % world_size == 0 - ffn_dim = ffn_dim // world_size - return hidden_dim, ffn_dim - - def create_weights(self, hidden_dim: int, ffn_dim: int, num_experts: int, - dtype: torch.dtype, device: torch.device): - """create weights.""" - gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim), - dtype=dtype, - device=device) - down_weights = torch.empty((num_experts, hidden_dim, ffn_dim), - dtype=dtype, - device=device) - return gate_up_weights, down_weights - def update_weights(self): """update weights.""" - gateup_loader = self.gate_up_weights.weight_loader - down_loader = self.down_weights.weight_loader gate_up_weights, down_weights = self.impl.update_weights( - self.gate_up_weights, self.down_weights) - gate_up_weights = torch.nn.Parameter(gate_up_weights, - requires_grad=False) - down_weights = torch.nn.Parameter(down_weights, requires_grad=False) - gate_up_weights.weight_loader = gateup_loader - down_weights.weight_loader = down_loader - gate_up_weights._weight_type = 'gate_up_weights' - down_weights._weight_type = 'down_weights' - self.register_parameter('gate_up_weights', gate_up_weights) - self.register_parameter('down_weights', down_weights) + self.gate_up.weight, self.down.weight) + self.gate_up.update_weight(gate_up_weights) + self.down.update_weight(down_weights) - def weight_loader_tp(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): - """weight loader.""" + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.LongTensor): + ret = self.impl.forward(hidden_states, topk_weights, topk_ids, + self.gate_up.weight, self.down.weight, + self.expert_list) + if self.all_reduce: + dist.all_reduce(ret) + return ret + + +class LinearWeightsW8A8(LinearWeights): + """fused moe linear w8a8 weights.""" + + def __init__(self, + num_experts: int, + in_features: int, + out_features: int, + weight_type: str, + device: torch.device, + expert_list: List[int] = None, + ep: bool = False): + super().__init__( + num_experts=num_experts, + in_features=in_features, + out_features=out_features, + weight_type=weight_type, + dtype=torch.int8, + device=device, + expert_list=expert_list, + ep=ep, + ) + scale = torch.empty((num_experts, out_features, 1), + dtype=torch.float32, + device=device) + scale = torch.nn.Parameter(scale, requires_grad=False) + self.register_parameter('scale', scale) + + if self.ep: + self.scale.weight_loader = self.weight_loader_ep + else: + self.scale.weight_loader = self.weight_loader_scale_tp + + def update_weight(self, weight: torch.Tensor, scale: torch.Tensor): + """update weight.""" + super().update_weight(weight=weight) + weight_loader = self.scale.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + scale.weight_loader = weight_loader + self.register_parameter('scale', scale) + + def weight_loader_scale_tp(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """weight loader scale tp.""" world_size, rank = get_world_rank() if shard_id == 'gate': - param_data = param.data[expert_id, :self.ffn_dim] + param_data = param.data[expert_id, :self.half_out] weight = loaded_weight.chunk(world_size, dim=0)[rank] elif shard_id == 'up': - param_data = param.data[expert_id, self.ffn_dim:] + param_data = param.data[expert_id, self.half_out:] weight = loaded_weight.chunk(world_size, dim=0)[rank] elif shard_id == 'down': param_data = param.data[expert_id] - weight = loaded_weight.chunk(world_size, dim=1)[rank] + weight = 
loaded_weight else: raise RuntimeError(f'Unknown shard_id: {shard_id}') param_data.copy_(weight) - def weight_loader_ep(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): - """weight loader.""" - expert_list = self.expert_list - if expert_id not in expert_list: - return - expert_map = self.expert_map - param_id = expert_map[expert_id] +class FusedMoEW8A8(nn.Module): + """fused moe w8a8.""" + + def __init__(self, + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + renormalize: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + all_reduce: bool = True, + enable_ep: bool = False): + super().__init__() + if device is None: + device = torch.device('cpu') + dtype = torch.float16 if dtype is None else dtype + + impl_builder = get_backend().get_layer_impl_builder( + OpType.FusedMoEW8A8) + self.impl = impl_builder.build(top_k, num_experts, renormalize, dtype) + + enable_ep = enable_ep and self.impl.support_ep() + if enable_ep: + world_size, rank = get_world_rank() + expert_list = self.impl.ep_expert_list(world_size, rank) + num_experts = len(expert_list) + else: + hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) + expert_list = None + self.expert_list = expert_list + + self.gate_up = LinearWeightsW8A8(num_experts, + hidden_dim, + ffn_dim * 2, + weight_type='gate_up', + device=device, + expert_list=expert_list, + ep=enable_ep) + self.down = LinearWeightsW8A8( + num_experts, + ffn_dim, + hidden_dim, + weight_type='down', + device=device, + expert_list=expert_list, + ep=enable_ep, + ) + + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.num_experts = num_experts + self.dtype = dtype + self.device = device + world_size, _ = get_world_rank() + if world_size == 1: + all_reduce = False + self.all_reduce = all_reduce + + def update_weights(self): + """update weights.""" + (gate_up_weights, down_weights, gate_up_scale, + down_scale) = self.impl.update_weights(self.gate_up.weight, + self.down.weight, + self.gate_up.scale, + self.down.scale) + self.gate_up.update_weight(gate_up_weights, gate_up_scale) + self.down.update_weight(down_weights, down_scale) + + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.LongTensor): + ret = self.impl.forward(hidden_states, topk_weights, topk_ids, + self.gate_up.weight, self.gate_up.scale, + self.down.weight, self.down.scale, + self.expert_list) + if self.all_reduce: + dist.all_reduce(ret) + return ret + + +class LinearWeightsBlockedF8(LinearWeights): + """fused moe linear blocked fp8 weights.""" + + def __init__(self, + num_experts: int, + in_features: int, + out_features: int, + weight_type: str, + block_size: int, + dtype: torch.dtype, + device: torch.device, + expert_list: List[int] = None, + ep: bool = False): + super().__init__( + num_experts=num_experts, + in_features=in_features, + out_features=out_features, + weight_type=weight_type, + dtype=dtype, + device=device, + expert_list=expert_list, + ep=ep, + ) + self.block_size = block_size + scale = torch.empty((num_experts, div_up( + out_features, block_size), div_up(in_features, block_size)), + dtype=torch.float32, + device=device) + scale = torch.nn.Parameter(scale, requires_grad=False) + self.register_parameter('scale', scale) + + if self.ep: + self.scale.weight_loader = self.weight_loader_ep + else: + self.scale.weight_loader = self.weight_loader_scale_tp + + def update_weight(self, weight: torch.Tensor, scale: torch.Tensor): + """update 
weight.""" + super().update_weight(weight=weight) + weight_loader = self.scale.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + scale.weight_loader = weight_loader + self.register_parameter('scale', scale) + + def weight_loader_scale_tp(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """weight loader scale tp.""" + world_size, rank = get_world_rank() + block_size = self.block_size + half_out = self.half_out // block_size if shard_id == 'gate': - param_data = param.data[param_id, :self.ffn_dim] + param_data = param.data[expert_id, :half_out] + weight = loaded_weight.chunk(world_size, dim=0)[rank] elif shard_id == 'up': - param_data = param.data[param_id, self.ffn_dim:] + param_data = param.data[expert_id, half_out:] + weight = loaded_weight.chunk(world_size, dim=0)[rank] elif shard_id == 'down': - param_data = param.data[param_id] + param_data = param.data[expert_id] + weight = loaded_weight.chunk(world_size, dim=1)[rank] else: raise RuntimeError(f'Unknown shard_id: {shard_id}') - param_data.copy_(loaded_weight) + param_data.copy_(weight) + + +class FusedMoEBlockedF8(nn.Module): + """fused moe blocked f8.""" + + def __init__(self, + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + renormalize: bool = False, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + all_reduce: bool = True, + enable_ep: bool = False): + super().__init__() + if device is None: + device = torch.device('cpu') + dtype = torch.float16 if dtype is None else dtype + self.block_size = 128 + impl_builder = get_backend().get_layer_impl_builder( + OpType.FusedMoEBlockedF8) + self.impl = impl_builder.build(top_k, + num_experts, + renormalize, + block_size=self.block_size, + out_dtype=dtype) + + enable_ep = enable_ep and self.impl.support_ep() + if enable_ep: + world_size, rank = get_world_rank() + expert_list = self.impl.ep_expert_list(world_size, rank) + num_experts = len(expert_list) + else: + hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) + expert_list = None + self.expert_list = expert_list + + self.gate_up = LinearWeightsBlockedF8(num_experts, + hidden_dim, + ffn_dim * 2, + weight_type='gate_up', + block_size=self.block_size, + dtype=fp8_dtype, + device=device, + expert_list=expert_list, + ep=enable_ep) + self.down = LinearWeightsBlockedF8( + num_experts, + ffn_dim, + hidden_dim, + weight_type='down', + block_size=self.block_size, + dtype=fp8_dtype, + device=device, + expert_list=expert_list, + ep=enable_ep, + ) + + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.num_experts = num_experts + self.dtype = dtype + self.device = device + world_size, _ = get_world_rank() + if world_size == 1: + all_reduce = False + self.all_reduce = all_reduce + + def update_weights(self): + """update weights.""" + (gate_up_weights, down_weights, gate_up_scale, + down_scale) = self.impl.update_weights(self.gate_up.weight, + self.down.weight, + self.gate_up.scale, + self.down.scale) + self.gate_up.update_weight(gate_up_weights, gate_up_scale) + self.down.update_weight(down_weights, down_scale) def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): ret = self.impl.forward(hidden_states, topk_weights, topk_ids, - self.gate_up_weights, self.down_weights, + self.gate_up.weight, self.gate_up.scale, + self.down.weight, self.down.scale, self.expert_list) if self.all_reduce: dist.all_reduce(ret) return ret + + 
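Each 128x128 block of an expert weight in the blocked-FP8 path above carries one float32 scale, so the scale shapes follow directly from div_up. A minimal, self-contained sketch of that bookkeeping (the sizes are hypothetical and it only assumes a PyTorch build that ships the float8 dtypes):

import torch

def div_up(a: int, b: int) -> int:
    """Ceiling division, as used for the per-block scale shapes."""
    return (a + b - 1) // b

# hypothetical sizes, for illustration only
num_experts, hidden_dim, ffn_dim, block_size = 8, 4096, 11008, 128

# fp8 expert weights and their per-block float32 scales
gate_up = torch.empty((num_experts, ffn_dim * 2, hidden_dim),
                      dtype=torch.float8_e4m3fn)
gate_up_scale = torch.empty((num_experts,
                             div_up(ffn_dim * 2, block_size),
                             div_up(hidden_dim, block_size)),
                            dtype=torch.float32)

# the 'gate' half of the scale rows ends at half_out // block_size,
# mirroring weight_loader_scale_tp above
half_out_scale_rows = (ffn_dim * 2 // 2) // block_size
assert gate_up_scale.shape[1] == 2 * half_out_scale_rows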
+def build_fused_moe( + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + renormalize: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + all_reduce: bool = True, + enable_ep: bool = False, + quant_config: Any = None, +): + """fused moe builder.""" + + if quant_config is None: + return FusedMoE( + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + num_experts=num_experts, + top_k=top_k, + renormalize=renormalize, + dtype=dtype, + device=device, + all_reduce=all_reduce, + enable_ep=enable_ep, + ) + + quant_method = quant_config['quant_method'] + if quant_method == 'smooth_quant': + return FusedMoEW8A8( + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + num_experts=num_experts, + top_k=top_k, + renormalize=renormalize, + dtype=dtype, + device=device, + all_reduce=all_reduce, + enable_ep=enable_ep, + ) + elif quant_method == 'fp8': + fmt = quant_config.get('fmt', 'e4m3') + if fmt == 'e4m3': + fp8_dtype = torch.float8_e4m3fn + elif fmt == 'e5m2': + fp8_dtype = torch.float8_e5m2 + else: + raise TypeError(f'Unsupported fp8 fmt: {fmt}') + return FusedMoEBlockedF8( + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + num_experts=num_experts, + top_k=top_k, + renormalize=renormalize, + fp8_dtype=fp8_dtype, + dtype=dtype, + device=device, + all_reduce=all_reduce, + enable_ep=enable_ep, + ) + else: + raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index ef244ff73f..7e2c820399 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -4,19 +4,23 @@ import torch from torch import nn +from lmdeploy.pytorch.distributed import get_world_rank + from ..backends import OpType, get_backend +from .utils import chunk_aligned, get_distribute_size def _is_w8a8(quant_config: Any): """is w8a8.""" - if quant_config is None: - return False - else: + quant_dtype = None + w8a8_flag = False + if quant_config is not None: quant_method = quant_config['quant_method'] - if quant_method == 'w8a8': - return True - else: - return False + if quant_method == 'smooth_quant': + w8a8_flag = True + quant_dtype = quant_config.get('quant_dtype', 'int8') + quant_dtype = eval(f'torch.{quant_dtype}') + return w8a8_flag, quant_dtype class RMSNorm(nn.Module): @@ -27,16 +31,44 @@ def __init__(self, eps: float = 1e-6, dtype: torch.dtype = None, device: torch.device = None, - quant_config: Any = None): + quant_config: Any = None, + tp: bool = False, + align: int = 1): super().__init__() backend = get_backend() - if _is_w8a8(quant_config): + + w8a8_flag, quant_dtype = _is_w8a8(quant_config) + if w8a8_flag: builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8) else: builder = backend.get_layer_impl_builder(OpType.RMSNorm) + + if tp: + world_size, rank = get_world_rank() + hidden_size = get_distribute_size(hidden_size, + world_size, + rank, + align=align) + self.register_parameter('weight', self.create_weight(hidden_size, dtype, device)) - self.impl = builder.build(hidden_size, eps) + if w8a8_flag: + self.impl = builder.build(hidden_size, + eps, + quant_dtype=quant_dtype) + else: + self.impl = builder.build(hidden_size, eps) + + if tp: + self.weight.weight_loader = self.weight_loader + self.align = align + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + """weight loader.""" + world_size, rank = get_world_rank() + loaded_weight = chunk_aligned(loaded_weight, world_size, 0, + self.align)[rank] + param.copy_(loaded_weight) @staticmethod def 
create_weight(hidden_size: int, diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py index 3289f858a7..085b12c3e9 100644 --- a/lmdeploy/pytorch/nn/utils.py +++ b/lmdeploy/pytorch/nn/utils.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + + def div_up(a: int, b: int): """div up.""" return (a + b - 1) // b @@ -11,7 +14,26 @@ def get_distribute_size(feature_size: int, """update feature size.""" assert feature_size % align == 0 aligned_size = feature_size // align - align_per_rank = div_up(aligned_size, world_size) - prev_feats = align_per_rank * rank - updated_aligned_size = min(align_per_rank, aligned_size - prev_feats) + # try to make every rank has same amount of feats + updated_aligned_size = aligned_size // world_size + # if there are still some remain, given them to + # each rank + if rank < aligned_size % world_size: + updated_aligned_size += 1 return updated_aligned_size * align + + +def chunk_aligned(weight: torch.Tensor, chunks: int, dim: int, align: int): + """chunk aligned.""" + if align == 1: + return weight.chunk(chunks, dim=dim) + size = weight.size(dim) + assert size % align == 0 + aligned_size = size // align + + # try best to evenly split chunks + align_per_chunk = aligned_size // chunks + remain = aligned_size % chunks + sections = [align_per_chunk + int(c < remain) for c in range(chunks)] + sections = [sec * align for sec in sections] + return weight.split(sections, dim=dim) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index e28e375965..0d901d75a3 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -274,11 +274,19 @@ def has_unfinished(self): return self.has_running() or self.has_waiting() def has_running(self): - return self.seq_manager.num_sequences(MessageStatus.RUNNING) > 0 + return self.num_running() > 0 def has_waiting(self): - return self.seq_manager.num_sequences(MessageStatus.WAITING) > 0 + return self.num_waiting() > 0 def get_block_tables(self, seqs: SeqList): """get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] + + def num_running(self): + """num running.""" + return self.seq_manager.num_sequences(MessageStatus.RUNNING) + + def num_waiting(self): + """num waiting.""" + return self.seq_manager.num_sequences(MessageStatus.WAITING) diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py index 7fa568651b..67452f78e3 100644 --- a/lmdeploy/pytorch/supported_models.py +++ b/lmdeploy/pytorch/supported_models.py @@ -47,9 +47,9 @@ # cogvlm-chat CogVLMForCausalLM=True, # llava - LlavaLlamaForCausalLM=True, + LlavaLlamaForCausalLM=False, # llava mistral - LlavaMistralForCausalLM=True, + LlavaMistralForCausalLM=False, # deepseekvl MultiModalityCausalLM=False, # StarCoder2 diff --git a/lmdeploy/pytorch/tools/make_inputs.py b/lmdeploy/pytorch/tools/make_inputs.py index f2d23830b7..053e7d0918 100644 --- a/lmdeploy/pytorch/tools/make_inputs.py +++ b/lmdeploy/pytorch/tools/make_inputs.py @@ -135,6 +135,7 @@ def __fill_kv_caches(kv_caches, past_key_values, block_offsets): return StepContext.new( inputs=model_inputs, + model_config=model_config, world_size=world_size, kv_caches=kv_caches, ) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 69bead8906..86c0936de2 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -231,9 +231,10 @@ def __call__(self, """Inference a batch of prompts. 
Args: - prompts (List[str] | str | List[Dict] | List[Dict]): a batch of - prompts. It accepts: string prompt, a list of string prompts, - a chat history in OpenAI format or a list of chat history. + prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a + batch of prompts. It accepts: string prompt, a list of string + prompts, a chat history in OpenAI format or a list of chat + history. gen_config (GenerationConfig | None): a instance of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -305,9 +306,10 @@ def batch_infer(self, """Inference a batch of prompts. Args: - prompts (List[str] | str | List[Dict] | List[Dict]): a batch of - prompts. It accepts: string prompt, a list of string prompts, - a chat history in OpenAI format or a list of chat history. + prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a + batch of prompts. It accepts: string prompt, a list of string + prompts, a chat history in OpenAI format or a list of chat + history. gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -382,9 +384,10 @@ def stream_infer( """Inference a batch of prompts with stream mode. Args: - prompts (List[str] | str | List[Dict] | List[Dict]): a batch of - prompts. It accepts: string prompt, a list of string prompts, - a chat history in OpenAI format or a list of chat history. + prompts (List[str] | str | List[Dict] | List[List[Dict]]]):a + batch of prompts. It accepts: string prompt, a list of string + prompts, a chat history in OpenAI format or a list of chat + history. gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -510,8 +513,8 @@ async def generate( if gen_config.stop_token_ids is None: gen_config.stop_token_ids = self.stop_words if not gen_config.do_sample: - logger.warn(f'GenerationConfig: {gen_config}') - logger.warn( + logger.warning(f'GenerationConfig: {gen_config}') + logger.warning( 'Since v0.6.0, lmdeploy add `do_sample` in ' 'GenerationConfig. It defaults to False, meaning greedy ' 'decoding. 
Please set `do_sample=True` if sampling ' diff --git a/lmdeploy/serve/gradio/vl.py b/lmdeploy/serve/gradio/vl.py index 103bcc5889..bf8ee87e68 100644 --- a/lmdeploy/serve/gradio/vl.py +++ b/lmdeploy/serve/gradio/vl.py @@ -70,8 +70,6 @@ def run_local(model_path: str, **kwargs): from lmdeploy.serve.vl_async_engine import VLAsyncEngine - if isinstance(backend_config, PytorchEngineConfig): - backend_config.thread_safe = True vision_config = VisionConfig(thread_safe=True) engine = VLAsyncEngine(model_path=model_path, model_name=model_name, @@ -115,10 +113,13 @@ def chat(chatbot, session, max_new_tokens, top_p, top_k, temperature): else: prompt = history[-1][0][0] images = history[-1][0][1:] - prompt = (prompt, images) - - logger.info('prompt: ' + str(prompt)) - prompt = engine.vl_prompt_template.prompt_to_messages(prompt) + # convert prompt into GPT4V format + prompt = [ + dict(role='user', content=[dict(type='text', text=prompt)]) + ] + for image in images: + prompt[0]['content'].append( + dict(type='image_data', image_data=dict(data=image))) t0 = time.perf_counter() inputs = _run_until_complete( engine._get_prompt_input(prompt, True, sequence_start, '')) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index f515e49d2e..eb424a0829 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -509,7 +509,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: for call_info in call_info_list ] except Exception as e: - logger.error(f'Exception: {e}') + logger.error(f'Failed to parse {text}. Exception: {e}.') return create_error_response( HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!') @@ -946,6 +946,20 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return JSONResponse(ret) +def handle_torchrun(): + """To disable mmengine logging logic when using torchrun.""" + + def dummy_get_device_id(): + return 0 + + if int(os.environ.get('LOCAL_RANK', -1)) > 0: + from lmdeploy.vl.model.utils import _set_func + + # the replacement can't be recovered + _set_func('mmengine.logging.logger._get_device_id', + dummy_get_device_id) + + @router.on_event('startup') async def startup_event(): if VariableInterface.proxy_url is None: @@ -1071,8 +1085,8 @@ def serve(model_path: str, ssl_certfile = os.environ['SSL_CERTFILE'] http_or_https = 'https' + handle_torchrun() _, pipeline_class = get_task(model_path) - VariableInterface.async_engine = pipeline_class( model_path=model_path, speculative_model=speculative_model, diff --git a/lmdeploy/serve/proxy/constants.py b/lmdeploy/serve/proxy/constants.py index 88d86a3e33..5bf6e67659 100644 --- a/lmdeploy/serve/proxy/constants.py +++ b/lmdeploy/serve/proxy/constants.py @@ -2,8 +2,8 @@ import enum -LATENCY_DEEQUE_LEN = 15 -API_TIMEOUT_LEN = 100 +LATENCY_DEQUE_LEN = 15 +API_READ_TIMEOUT = 100 class Strategy(enum.Enum): diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 5f05930bd0..392ede3267 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
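The Gradio VL demo change above drops the prompt-template path and assembles the request directly in the OpenAI GPT-4V message format, the same structure the refactored VLAsyncEngine consumes later in this diff. A small illustrative sketch of that structure (the helper name and the blank test image are made up for the example):

from PIL import Image

def to_gpt4v_message(text, images):
    """Build one user turn: a text item followed by image_data items."""
    content = [dict(type='text', text=text)]
    for image in images:
        content.append(dict(type='image_data', image_data=dict(data=image)))
    return [dict(role='user', content=content)]

# illustrative usage with a blank 64x64 image
messages = to_gpt4v_message('describe the image', [Image.new('RGB', (64, 64))])
assert messages[0]['content'][1]['type'] == 'image_data'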
+import asyncio import copy import json import os @@ -18,14 +19,15 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field +from requests.exceptions import RequestException from lmdeploy.serve.openai.api_server import (check_api_key, create_error_response) from lmdeploy.serve.openai.protocol import ( # noqa: E501 ChatCompletionRequest, CompletionRequest, ModelCard, ModelList, ModelPermission) -from lmdeploy.serve.proxy.constants import (API_TIMEOUT_LEN, - LATENCY_DEEQUE_LEN, ErrorCodes, +from lmdeploy.serve.proxy.constants import (API_READ_TIMEOUT, + LATENCY_DEQUE_LEN, ErrorCodes, Strategy, err_msg) from lmdeploy.utils import get_logger @@ -36,7 +38,7 @@ class Status(BaseModel): """Status protocol consists of models' information.""" models: Optional[List[str]] = Field(default=[], examples=[[]]) unfinished: int = 0 - latency: Deque = Field(default=deque(maxlen=LATENCY_DEEQUE_LEN), + latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]]) speed: Optional[int] = Field(default=None, examples=[None]) @@ -87,6 +89,9 @@ def __init__(self, with open(self.config_path, 'r') as config_file: self.nodes = yaml.safe_load(config_file)['nodes'] for url, status in self.nodes.items(): + latency = deque(status.get('latency', []), + maxlen=LATENCY_DEQUE_LEN) + status['latency'] = latency status = Status(**status) self.nodes[url] = status self.heart_beat_thread = threading.Thread(target=heart_beat_controller, @@ -99,7 +104,7 @@ def update_config_file(self): nodes = copy.deepcopy(self.nodes) for url, status in nodes.items(): nodes[url] = status.model_dump() - nodes[url]['latency'] = list(status.latency) + nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:] with open(self.config_path, 'w') as config_file: # update cfg yml yaml.dump(dict(nodes=nodes), config_file) @@ -149,7 +154,8 @@ def remove_stale_nodes_by_expiration(self): to_be_deleted.append(node_url) for node_url in to_be_deleted: self.remove(node_url) - logger.info(f'Removed node_url: {node_url}') + logger.info(f'Removed node_url: {node_url} ' + 'due to heart beat expiration') @property def model_list(self): @@ -251,7 +257,7 @@ def handle_unavailable_model(self, model_name): Args: model_name (str): the model in the request. """ - logger.info(f'no model name: {model_name}') + logger.warning(f'no model name: {model_name}') ret = { 'error_code': ErrorCodes.MODEL_NOT_FOUND, 'text': err_msg[ErrorCodes.MODEL_NOT_FOUND], @@ -260,51 +266,54 @@ def handle_unavailable_model(self, model_name): def handle_api_timeout(self, node_url): """Handle the api time out.""" - logger.info(f'api timeout: {node_url}') + logger.warning(f'api timeout: {node_url}') ret = { - 'error_code': ErrorCodes.API_TIMEOUT, + 'error_code': ErrorCodes.API_TIMEOUT.value, 'text': err_msg[ErrorCodes.API_TIMEOUT], } return json.dumps(ret).encode() + b'\n' - def stream_generate(self, request: Dict, node_url: str, node_path: str): + def stream_generate(self, request: Dict, node_url: str, endpoint: str): """Return a generator to handle the input request. Args: request (Dict): the input request. node_url (str): the node url. - node_path (str): the node path. Such as `/v1/chat/completions`. + endpoint (str): the endpoint. Such as `/v1/chat/completions`. 
""" try: response = requests.post( - node_url + node_path, + node_url + endpoint, json=request, - stream=request['stream'], - timeout=API_TIMEOUT_LEN, + stream=True, + timeout=(5, API_READ_TIMEOUT), ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\n'): if chunk: yield chunk + b'\n\n' - except requests.exceptions.RequestException as e: # noqa + except (Exception, GeneratorExit, RequestException) as e: # noqa + logger.error(f'catched an exception: {e}') + # exception happened, reduce unfinished num yield self.handle_api_timeout(node_url) - async def generate(self, request: Dict, node_url: str, node_path: str): + async def generate(self, request: Dict, node_url: str, endpoint: str): """Return a the response of the input request. Args: request (Dict): the input request. node_url (str): the node url. - node_path (str): the node path. Such as `/v1/chat/completions`. + endpoint (str): the endpoint. Such as `/v1/chat/completions`. """ try: import httpx async with httpx.AsyncClient() as client: - response = await client.post(node_url + node_path, + response = await client.post(node_url + endpoint, json=request, - timeout=API_TIMEOUT_LEN) + timeout=API_READ_TIMEOUT) return response.text - except requests.exceptions.RequestException as e: # noqa + except (Exception, GeneratorExit, RequestException, asyncio.CancelledError) as e: # noqa # yapf: disable + logger.error(f'catched an exception: {e}') return self.handle_api_timeout(node_url) def pre_call(self, node_url): @@ -381,7 +390,11 @@ def add_node(node: Node, raw_request: Request = None): RPM or other metric. All the values of nodes should be the same metric. """ try: - node_manager.add(node.url, node.status) + res = node_manager.add(node.url, node.status) + if res is not None: + logger.error(f'add node {node.url} failed, {res}') + return res + logger.info(f'add node {node.url} successfully') return 'Added successfully' except: # noqa return 'Failed to add, please check the input url.' @@ -392,8 +405,10 @@ def remove_node(node_url: str): """Show available models.""" try: node_manager.remove(node_url) + logger.info(f'delete node {node_url} successfully') return 'Deleted successfully' except: # noqa + logger.error(f'delete node {node_url} failed.') return 'Failed to delete, please check the input url.' @@ -407,28 +422,50 @@ async def chat_completions_v1(request: ChatCompletionRequest, The request should be a JSON object with the following fields: - model: model name. Available from /v1/models. - - messages: string prompt or chat history in OpenAI format. A example - for chat history is `[{"role": "user", "content":"knock knock"}]`. + - messages: string prompt or chat history in OpenAI format. Chat history + example: `[{"role": "user", "content": "hi"}]`. - temperature (float): to modulate the next token probability - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - - max_tokens (int): output token nums + - max_tokens (int | None): output token nums. Default to None. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty - stop (str | List[str] | None): To stop generating further tokens. Only accept stop words that's encoded to one token idex. 
+ - response_format (Dict | None): Only pytorch backend support formatting + response. Examples: `{"type": "json_schema", "json_schema": {"name": + "test","schema": {"properties": {"name": {"type": "string"}}, + "required": ["name"], "type": "object"}}}` + or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}` + - logit_bias (Dict): Bias to logits. Only supported in pytorch engine. + - tools (List): A list of tools the model may call. Currently, only + internlm2 functions are supported as a tool. Use this to specify a + list of functions for which the model can generate JSON inputs. + - tool_choice (str | object): Controls which (if any) tool is called by + the model. `none` means the model will not call any tool and instead + generates a message. Specifying a particular tool via {"type": + "function", "function": {"name": "my_function"}} forces the model to + call that tool. `auto` or `required` will put all the tools information + to the model. Additional arguments supported by LMDeploy: + - top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering - ignore_eos (bool): indicator for ignoring eos - - session_id (int): if not specified, will set random value + - skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be True. + - min_new_tokens (int): To generate at least numbers of tokens. + - min_p (float): Minimum token probability, which will be scaled by the + probability of the most likely token. It must be a value between + 0 and 1. Typical values are in the 0.01-0.2 range, comparably + selective as setting `top_p` in the 0.99-0.8 range (use the + opposite of normal `top_p` values) Currently we do not support the following features: - - function_call (Users should implement this by themselves) - - logit_bias (not supported yet) - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ @@ -439,6 +476,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, if not node_url: return node_manager.handle_unavailable_model(request.model) + logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() start = node_manager.pre_call(node_url) if request.stream is True: @@ -465,13 +503,13 @@ async def completions_v1(request: CompletionRequest, - model (str): model name. Available from /v1/models. - prompt (str): the input prompt. - suffix (str): The suffix that comes after a completion of inserted text. - - max_tokens (int): output token nums + - max_tokens (int): output token nums. Default to 16. - temperature (float): to modulate the next token probability - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty @@ -481,7 +519,8 @@ async def completions_v1(request: CompletionRequest, Additional arguments supported by LMDeploy: - ignore_eos (bool): indicator for ignoring eos - - session_id (int): if not specified, will set random value + - skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be True. 
- top_k (int): The number of the highest probability vocabulary tokens to keep for top-k-filtering @@ -497,6 +536,7 @@ async def completions_v1(request: CompletionRequest, if not node_url: return node_manager.handle_unavailable_model(request.model) + logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() start = node_manager.pre_call(node_url) if request.stream is True: @@ -517,6 +557,7 @@ def proxy(server_name: str = '0.0.0.0', 'min_observed_latency'] = 'min_expected_latency', api_keys: Optional[Union[List[str], str]] = None, ssl: bool = False, + log_level: str = 'INFO', **kwargs): """To launch the proxy server. @@ -540,6 +581,7 @@ def proxy(server_name: str = '0.0.0.0', if ssl: ssl_keyfile = os.environ['SSL_KEYFILE'] ssl_certfile = os.environ['SSL_CERTFILE'] + logger.setLevel(log_level) uvicorn.run(app=app, host=server_name, port=server_port, diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index c293cd71c8..becf1b76fb 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -1,148 +1,208 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Union +import asyncio +from typing import Dict, List, Literal, Optional, Tuple, Union -import numpy as np +import PIL +from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig, + VisionConfig) from lmdeploy.pytorch.check_env import try_import_deeplink from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX, IMAGE_TOKEN from lmdeploy.vl.engine import ImageEncoder -from lmdeploy.vl.templates import VLPromptType, get_vl_prompt_template +from lmdeploy.vl.utils import load_image logger = get_logger('lmdeploy') +VLPromptType = Union[str, Tuple[str, PIL.Image.Image], + Tuple[str, List[PIL.Image.Image]]] + class VLAsyncEngine(AsyncEngine): """Visual Language Async inference engine.""" - def __init__(self, model_path: str, **kwargs) -> None: - vision_config = kwargs.pop('vision_config', None) - backend_config = kwargs.get('backend_config', None) - if kwargs.get('backend', '') == 'pytorch': + def __init__(self, + model_path: str, + backend: Literal['turbomind', 'pytorch'] = 'turbomind', + backend_config: Optional[Union[TurbomindEngineConfig, + PytorchEngineConfig]] = None, + vision_config: Optional[VisionConfig] = None, + **kwargs) -> None: + if backend == 'pytorch': try_import_deeplink(backend_config.device_type) self.vl_encoder = ImageEncoder(model_path, + backend, vision_config, backend_config=backend_config) - super().__init__(model_path, **kwargs) + super().__init__(model_path, + backend=backend, + backend_config=backend_config, + **kwargs) if self.model_name == 'base': raise RuntimeError( 'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501 ) - self.vl_prompt_template = get_vl_prompt_template( - model_path, self.chat_template, self.model_name) - def _convert_prompts(self, + @classmethod + def _convert_prompts(cls, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]]): - """convert prompts to openai format.""" + """convert prompts to openai GPT4V format.""" if isinstance(prompts, str) or isinstance(prompts, tuple): - _prompts = self.vl_prompt_template.prompt_to_messages(prompts) + _prompts = cls.prompt_to_messages(prompts) elif isinstance(prompts[0], tuple) or isinstance(prompts[0], str): - 
_prompts = [ - self.vl_prompt_template.prompt_to_messages(x) for x in prompts - ] + _prompts = [cls.prompt_to_messages(x) for x in prompts] else: _prompts = prompts return _prompts async def _get_prompt_input(self, - prompt: Dict, + messages: Union[str, List[Dict]], do_preprocess: bool, sequence_start: bool, adapter_name: str, tools: Optional[List[object]] = None, **kwargs): - """get input_ids, embeddings and offsets.""" - if do_preprocess: - decorated = self.vl_prompt_template.messages2prompt( - prompt, sequence_start) - else: - decorated = prompt - segs = decorated.split(IMAGE_TOKEN) - - results = {} - input_ids = [] - from lmdeploy.vl.templates import (MllamaTempateWrapper, - MolmoChatTemplateWrapper, - Qwen2VLChatTemplateWrapper) - ranges = None - grid_thws = None - if len(segs) > 1: - # yapf: disable - images_with_kwargs = await self.vl_prompt_template.async_collect_pil_images(prompt) # noqa: E501 - # yapf: enable - features = [] - if len(images_with_kwargs) > 0: - images, image_kwargs = list(zip(*images_with_kwargs)) - features = await self.vl_encoder.async_infer( - images, image_kwargs) - - from lmdeploy.vl.templates import MiniCPMVTempateWrapper - if isinstance(self.vl_prompt_template, MiniCPMVTempateWrapper): - decorated, features = self.vl_prompt_template.update_image_token( # noqa: E501 - decorated, features) - segs = decorated.split(IMAGE_TOKEN) - - if isinstance(self.vl_prompt_template, - Qwen2VLChatTemplateWrapper): - grid_thws = [x['grid_thw'] for x in features] - features = [x['embeddings'] for x in features] - - if isinstance(self.vl_prompt_template, MllamaTempateWrapper): - # llama3.2 just encode <|image|> and inference - decorated = decorated.replace(IMAGE_TOKEN, '<|image|>') - input_ids = self.tokenizer.encode(decorated, - add_bos=sequence_start) - results['input_ids'] = input_ids - results['prompt'] = decorated - assert len(features) - results['cross_attention_states'] = features[0] - return results - - if isinstance(self.vl_prompt_template, - MolmoChatTemplateWrapper): - return features[0] - - features = [x.cpu().numpy() for x in features] - input_ids = [] - begins = [] - ends = [] - if len(segs) != len(features) + 1: - logger.error( - f'the number of {IMAGE_TOKEN} is not equal ' - f'to input images, {len(segs) - 1} vs {len(features)}') - features = features[:len(segs) - 1] - for i, seg in enumerate(segs): - if i > 0 and i <= len(features): - image_dim = features[i - 1].shape[0] - begins.append(len(input_ids)) - ends.append(begins[-1] + image_dim) - input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) - seg_ids = self.tokenizer.encode(seg, - add_bos=((i == 0) - and sequence_start)) - input_ids.extend(seg_ids) - ranges = np.stack([begins, ends], axis=1).tolist() - results['input_embeddings'] = features or None - results['input_embedding_ranges'] = ranges or None + """process messages and return the required data for the inference + engines. + + Refer to pytorch.engine.EngineInstance.async_stream_infer and + turbomind.TurboMindInstance.async_stream_infer for the argument + specification. 
+ """ + if isinstance(messages, str): + return await super()._get_prompt_input(messages, do_preprocess, + sequence_start, + adapter_name, tools, + **kwargs) + elif isinstance(messages, List): + has_multimodal_input = any( + isinstance(message['content'], list) and any( + item['type'] in ['image_url', 'image_data'] + for item in message['content']) for message in messages) + if not has_multimodal_input: + return await super()._get_prompt_input(messages, do_preprocess, + sequence_start, + adapter_name, tools, + **kwargs) else: - input_ids = self.tokenizer.encode(decorated, - add_bos=sequence_start) - - if isinstance(self.vl_prompt_template, Qwen2VLChatTemplateWrapper): - # TODO: refactor _get_prompt_input function - mrope_position_ids, mrope_position_delta = \ - self.vl_prompt_template.get_mrope_info( - len(input_ids), grid_thws=grid_thws, - embedding_ranges=ranges) - results['mrope_position_ids'] = mrope_position_ids - results['mrope_position_delta'] = mrope_position_delta - - results['input_ids'] = input_ids - results['prompt'] = decorated + raise RuntimeError(f'unsupported messages {messages}') + + messages = await self.async_convert_to_pil_images(messages) + results = await self.vl_encoder.preprocess(messages) + if self.backend == 'turbomind': + # for tm engine, this module perform vision embedding after image + # preprocessing. It utilizes the hf model's vision embeddings + # functions and returns the input_ids, input_embeddings, + # embedding_ranges and so on. All the returned values are passed + # to tm engine for token generation + results = await self.vl_encoder.async_infer(results) + results = await self.vl_encoder.wrap_for_turbomind( + results, self.chat_template, self.tokenizer, sequence_start) + elif self.backend == 'pytorch': + # for pt engine, this module only conduct the image preprocessing + # It leaves the vision embedding to the pt engine + results = await self.vl_encoder.wrap_for_pytorch( + results, self.chat_template, self.tokenizer, sequence_start) return results + @classmethod + async def async_convert_to_pil_images(cls, + messages: List[Dict]) -> List[Dict]: + """Scan the provided messages to find image URLs or base64-encoded + image data. Loads the images into Pillow image objects. + + Args: + messages (List[Dict]): a user request of GPT4V message format + """ + if isinstance(messages, Dict): + messages = [messages] + assert isinstance(messages, List) + + out_messages = [None] * len(messages) + + def _inner_call(i, in_messages, out_messages): + role = in_messages[i]['role'] + content = in_messages[i]['content'] + assert role in ['system', 'user', 'assistant'], \ + f'unsupported role "{role}"' + if role != 'user' or isinstance(content, str): + # the content is a user's prompt or an assistant's prompt, + # returning it directly + out_messages[i] = in_messages[i] + return + # the role is a user and the content is a list, in which there + # might be image_url or image_data + assert isinstance(content, List) + message = dict(role=role, content=[]) + for item in content: + # image url or base64-encoded image data + if item['type'] == 'image_url': + """ + convert the following item: + { + 'type': 'image_url', + 'image_url': { + 'url': 'image url or base64-encoded image data', + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... 
+ } + """ # noqa + data = item['image_url'].copy() + try: + url = data.pop('url') + image = load_image(url) + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'image_data': + """ + convert the following item: + { + 'type': 'image_data', + 'image_data': { + 'data': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_data'].copy() + try: + image = data.pop('data') + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'text': + message['content'].append(item) + else: + logger.error(f'unexpected content type {message}') + out_messages[i] = message + + await asyncio.gather(*[ + asyncio.get_event_loop().run_in_executor(None, _inner_call, i, + messages, out_messages) + for i in range(len(messages)) + ]) + return out_messages + def batch_infer(self, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]], **kwargs): @@ -173,3 +233,46 @@ def chat(self, prompts: VLPromptType, **kwargs): last_round = sess.history[-1] sess.history[-1] = (prompts, last_round[-1]) return sess + + @classmethod + def prompt_to_messages(cls, prompt: VLPromptType): + """convert prompt to GTP4V format.""" + messages = { + 'role': 'user', + 'content': [{ + 'type': 'text', + 'text': '', + }] + } + if isinstance(prompt, str): + messages['content'][0]['text'] = prompt + else: + prompt, images = prompt + if not isinstance(images, list): + images = [images] + messages['content'][0]['text'] = prompt + for image in images: + # 'image_url': means url or local path to image. + # 'image_data': means PIL.Image.Image object. 
+ if isinstance(image, str): + image = load_image(image) + item = { + 'type': 'image_data', + 'image_data': { + 'data': image + } + } + elif isinstance(image, PIL.Image.Image): + item = { + 'type': 'image_data', + 'image_data': { + 'data': image + } + } + else: + raise ValueError( + 'image should be a str(url/path) or PIL.Image.Image') + + messages['content'].append(item) + + return [messages] diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 77f0bc8dc8..176c3191f4 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -6,7 +6,7 @@ import fire import torch -from lmdeploy.archs import get_model_arch +from lmdeploy.archs import get_model_arch, search_nested_config from lmdeploy.messages import TurbomindEngineConfig from lmdeploy.model import MODELS, best_match_model from lmdeploy.utils import get_logger, get_model @@ -129,16 +129,17 @@ def get_output_model_registered_name_and_config(model_path: str, ] else 'float16' elif dtype in ['float16', 'bfloat16']: if weight_type == 'int4': - logger.warn(f'The model {model_path} is a quantized model, so the ' - f'specified data type {dtype} is ignored') + logger.warning( + f'The model {model_path} is a quantized model, so the ' + f'specified data type {dtype} is ignored') else: weight_type = dtype else: assert 0, f'unsupported specified data type {dtype}' if weight_type == 'bfloat16' and not is_bf16_supported(): - logger.warn('data type fallback to float16 since ' - 'torch.cuda.is_bf16_supported is False') + logger.warning('data type fallback to float16 since ' + 'torch.cuda.is_bf16_supported is False') weight_type = 'float16' config.model_config.model_arch = model_arch config.model_config.weight_type = weight_type @@ -174,23 +175,6 @@ def pack_model_repository(workspace_path: str): dst=osp.join(model_repo_dir, 'postprocessing')) -def find_quantization_config(nested, target_key): - if isinstance(nested, dict): - for key, value in nested.items(): - if key == target_key: - return value - if isinstance(value, (dict, list)): - result = find_quantization_config(value, target_key) - if result is not None: - return result - elif isinstance(nested, list): - for item in nested: - result = find_quantization_config(item, target_key) - if result is not None: - return result - return None - - def get_tm_model(model_path, model_name, chat_template_name, @@ -213,8 +197,7 @@ def get_tm_model(model_path, If it is None, the turbomind model won't be saved """ _, cfg = get_model_arch(model_path) - quant_config = find_quantization_config(cfg.to_dict(), - 'quantization_config') + quant_config = search_nested_config(cfg.to_dict(), 'quantization_config') if quant_config: quant_method = quant_config.get('quant_method') _group_size = int(quant_config.get('group_size', 0)) diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py index 52497175ef..1754161ff5 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -191,7 +191,7 @@ def __init__(self, model: BaseOutputModel): self.attn_bias = model.model_config.attn_bias def _reorder_and_merge(self, qkvo): - q, k, v, o = map(transpose, qkvo) + q, k, v, o = qkvo # reorder output dim for tm's rotary embedding layout if self.model.permute_qk: q = permute_v2(q, self.head_dim) @@ -202,6 +202,27 @@ def _reorder_and_merge(self, qkvo): o = torch.zeros_like(q) return qkv, o + def _repeat_kv(self, qkvo, kind: str): + """replicate kv.""" + q, k, v, o = qkvo + head_dim = 
self.model.model_config.size_per_head + hidden_dim = self.model.model_config.hidden_units + + def _repeat(x): + dim = hidden_dim if kind != 'bias' else 1 + x = x.reshape(dim, -1, head_dim) + x = x.repeat(1, 1, self.model.repeat_kv) + x = x.reshape(dim, -1) + return x + + k, v = map(_repeat, (k, v)) + if kind == 'bias': + if o is None: + o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device) + q, k, v, o = map(torch.squeeze, (q, k, v, o)) + + return (q, k, v, o) + def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs): if all(x is None for x in qkvo): return @@ -209,6 +230,9 @@ def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs): if is_lora_a: qkv, o = map(transpose, qkvo) else: + qkvo = tuple(map(transpose, qkvo)) + if self.model.repeat_kv: + qkvo = self._repeat_kv(qkvo, kind) qkv, o = self._reorder_and_merge(qkvo) self.model.save_split(pack_fn(qkv), self._attn.format(idx, 'w_qkv', kind), diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index f2c981bb24..7ea1a84f35 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -78,6 +78,17 @@ def __init__(self, self.model_config.expert_inter_size = _pad_inter_size( self.model_config.expert_inter_size, self.model_config.group_size, self.tensor_para_size) + + # head_num is divisble by tp but kv_head_num is not + # and tp is divisble by kv_head_num + assert self.model_config.head_num % self.tensor_para_size == 0 + self.repeat_kv = 0 + if (self.tensor_para_size > self.model_config.kv_head_num and + self.tensor_para_size % self.model_config.kv_head_num == 0): + self.repeat_kv = (self.tensor_para_size // + self.model_config.kv_head_num) + self.model_config.kv_head_num = self.tensor_para_size + self.model_config.verify() assert self.model_config.kv_head_num % self.tensor_para_size == 0 diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index 11e99edfa0..2b9c5156ed 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.archs import get_model_arch +from lmdeploy.archs import get_model_arch, search_nested_config from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -80,7 +80,12 @@ def _is_head_dim_supported(cfg): if os.path.exists(triton_model_path): support_by_turbomind = True else: + arch, cfg = get_model_arch(model_path) + quant_method = search_nested_config(cfg.to_dict(), 'quant_method') + if quant_method and quant_method in ['smooth_quant']: + # tm hasn't support quantized models by applying smoothquant + return False if arch in SUPPORTED_ARCHS.keys(): support_by_turbomind = True diff --git a/lmdeploy/version.py b/lmdeploy/version.py index d9f4307a78..0b4b8a5379 100644 --- a/lmdeploy/version.py +++ b/lmdeploy/version.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple -__version__ = '0.6.3' +__version__ = '0.6.5' short_version = __version__ diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py index 124fd537c6..7d490b2b77 100644 --- a/lmdeploy/vl/engine.py +++ b/lmdeploy/vl/engine.py @@ -1,13 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
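The kv-head replication introduced above in lmdeploy/turbomind/deploy/module.py and target_model/base.py covers the case where the tensor-parallel size exceeds the number of kv heads: base.py computes repeat_kv = tp // kv_head_num (only when tp is an exact multiple of kv_head_num) and raises kv_head_num to tp, while module.py duplicates each k/v head block before the qkv merge. A minimal sketch of that duplication step, with made-up shapes rather than real converter tensors:

import torch

# Assumed, illustrative shapes: 2 kv heads of dim 4 in a model with
# hidden_dim 8, deployed with tensor-parallel size tp = 4, so each kv head
# must appear repeat_kv = tp // kv_head_num = 2 times.
hidden_dim, kv_head_num, head_dim, tp = 8, 2, 4, 4
repeat_kv = tp // kv_head_num

k = torch.randn(hidden_dim, kv_head_num * head_dim)  # transposed k weight

# Same idea as the _repeat closure in _repeat_kv(): view as
# (dim, heads, head_dim), duplicate each head block, flatten back to 2-D.
x = k.reshape(hidden_dim, -1, head_dim)
x = x.repeat(1, 1, repeat_kv)
k_replicated = x.reshape(hidden_dim, -1)

assert k_replicated.shape == (hidden_dim, tp * head_dim)

The duplication trades a little memory for a uniform per-rank layout, which is why the path is only taken when tp is divisible by kv_head_num; otherwise the existing assertion that kv_head_num is divisible by tp still applies.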
import asyncio -import inspect -import queue -import time -from threading import Thread +from concurrent.futures import ThreadPoolExecutor from typing import Dict, List, Optional, Union import torch -from PIL.Image import Image from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig, VisionConfig) @@ -27,169 +23,95 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None: raise e -class Record: - """Batching manager.""" - - def __init__(self, thread_safe): - self.thread_safe = thread_safe - self.number = [] - self.waiting = [] - self.kwargs = [] - self.done = [] - self.res_que = [] - self.total = 0 - - def enqueue(self, images: List[Image], kwargs: List[Dict], - que: Union[queue.Queue, asyncio.Queue]): - """add ith request to manager.""" - self.number.append(len(images)) - self.waiting.extend(images) - self.kwargs.extend(kwargs) - self.res_que.append(que) - self.total += len(images) - self.log('received', len(images)) - - def dequeue(self, max_batch_size): - """try to dequeue max batch size images.""" - inputs = self.waiting[:max_batch_size] - kwargs = self.kwargs[:max_batch_size] - self.waiting = self.waiting[max_batch_size:] - self.kwargs = self.kwargs[max_batch_size:] - self.total -= len(inputs) - self.log('process', len(inputs)) - return inputs, kwargs - - def notify(self): - """set result if request i is finished.""" - if len(self.number) == 0 or self.number[0] > len(self.done): - return False - num_images = self.number.pop(0) - outputs = self.done[:num_images] - self.done = self.done[num_images:] - que = self.res_que.pop(0) - self.log('done', num_images) - if self.thread_safe: - que._loop.call_soon_threadsafe(que.put_nowait, outputs) - else: - que.put_nowait(outputs) - return True - - def log(self, task: str, num: int): - logger.info(f'ImageEncoder {task} {num} images, ' - f'left {self.total} images.') - - class ImageEncoder: """Image encoder.""" - def __init__(self, - model_path: str, - vision_config: VisionConfig = None, - backend_config: Optional[Union[TurbomindEngineConfig, - PytorchEngineConfig]] = None): - self.model = load_vl_model(model_path, backend_config=backend_config) + def __init__( + self, + model_path: str, + backend: str, + vision_config: VisionConfig = None, + backend_config: Optional[Union[TurbomindEngineConfig, + PytorchEngineConfig]] = None, + ): + self.model = load_vl_model(model_path, + backend, + backend_config=backend_config) if vision_config is None: vision_config = VisionConfig() self.vision_config = vision_config self.max_batch_size = vision_config.max_batch_size + self.executor = ThreadPoolExecutor(max_workers=1) torch.cuda.empty_cache() - self._que: asyncio.Queue = None - self._loop_task: asyncio.Task = None - if vision_config.thread_safe: - self._create_thread_safe_task() - - def _create_thread_safe_task(self): - """thread safe loop task.""" - self._loop = asyncio.new_event_loop() - def _work_thread(): - asyncio.set_event_loop(self._loop) - self._que = asyncio.Queue() - self._loop.run_until_complete(self._forward_loop()) - - thread = Thread(target=_work_thread, daemon=True) - thread.start() - self._loop_thread = thread - - def _create_event_loop_task(self): - """event loop task.""" - task = asyncio.get_event_loop().create_task(self._forward_loop()) - self._loop_task = task - self._loop = task.get_loop() - - @property - def req_que(self): - if self.vision_config.thread_safe: - return self._que - if self._que is None: - self._que = asyncio.Queue() - if self._loop_task is None: - self._create_event_loop_task() - if 
asyncio.get_event_loop() != self._loop: - raise RuntimeError('Current event loop is different from' - ' the one bound to loop task!') - return self._que - - async def _forward_loop(self): - """working loop to process images.""" - logger.info('start ImageEncoder._forward_loop') - record = Record(self.vision_config.thread_safe) - while True: - while record.total == 0 or (self._que.qsize() and - record.total < self.max_batch_size): - while self._que.qsize() == 0: - await asyncio.sleep(0.01) - item = await self._que.get() - record.enqueue(item[0], item[1], item[2]) - inputs, kwargs = record.dequeue(self.max_batch_size) - future = asyncio.get_event_loop().run_in_executor( - None, self.forward, inputs, kwargs) - future.add_done_callback(_raise_exception_on_finish) - outputs = await future - record.done.extend(outputs) - while record.notify(): - pass - - def _init_input_params(self, - inputs: List[Image], - params: List[Dict] = None): - """Check and init inputs params.""" - if params is None: - params = [{}] * len(inputs) - assert len(params) == len(inputs), \ - 'different length of inputs and kwargs' - return params - - def forward(self, inputs: List[Image], params: List[Dict] = None): - """Model forward.""" - params = self._init_input_params(inputs, params) - time_start = time.perf_counter() - func_params = inspect.signature(self.model.forward).parameters - func_inputs = [inputs, params] if len(func_params) > 1 else [inputs] - outputs = self.model.forward(*func_inputs) - if isinstance(outputs[0], torch.Tensor): - outputs = [x.cpu() for x in outputs] - time_end = time.perf_counter() - logger.info(f'ImageEncoder forward {len(inputs)} images, ' - f'cost {time_end - time_start:.3f}s') + async def preprocess(self, messages: List[Dict]) -> List[Dict]: + """preprocess multimodal data in the messages.""" + future = asyncio.get_event_loop().run_in_executor( + self.executor, self.model.preprocess, messages) + future.add_done_callback(_raise_exception_on_finish) + outputs = await future return outputs - def infer(self, inputs: List[Image], params: List[Dict] = None): - """infer.""" - params = self._init_input_params(inputs, params) - results = self.forward(inputs, params) - return results + async def async_infer(self, messages: List[Dict]) -> List[Dict]: + """get multimodal embedding. + + Args: + messages (List[Dict]): a list of message, which is the output + of `preprocess()` + """ + future = asyncio.get_event_loop().run_in_executor( + self.executor, self.model.forward, messages, self.max_batch_size) + future.add_done_callback(_raise_exception_on_finish) + outputs = await future + return outputs - async def async_infer(self, - inputs: List[Image], - params: List[Dict] = None): - """async infer.""" - params = self._init_input_params(inputs, params) - outputs = asyncio.Queue() - item = (inputs, params, outputs) - if self.vision_config.thread_safe: - self._loop.call_soon_threadsafe(self._que.put_nowait, item) - else: - self.req_que.put_nowait(item) - results = await outputs.get() - return results + async def wrap_for_pytorch(self, messages: List[Dict], chat_template, + tokenizer, sequence_start) -> List[Dict]: + """ + Args: + messages (List[Dict]): a list of message, which is supposed to be + the output of `preprocess` + Returns: + a dict which will be passed to pytorch engine_instance's forward. + The dict is like the following: + Dict( + 'prompt': 'the prompt after applying chat template' + 'input_ids': [], + 'multimodal': { + 'pixel_values': torch.Tensor, + ... 
+ ] + ) + """ + result = self.model.to_pytorch(messages, chat_template, tokenizer, + sequence_start) + # clear data + for i, message in enumerate(messages): + if isinstance(message['content'], List): + messages[i]['preprocess'] = None + return result + + async def wrap_for_turbomind(self, messages: List[Dict], chat_template, + tokenizer, sequence_start) -> Dict: + """ + Args: + messages (List[Dict]): a list of message, which is supposed to be + the output of `async_infer` + Returns: + a dict which will be passed to pytorch engine_instance's forward. + The dict is like the following: + Dict( + 'prompt': 'the prompt after applying chat template' + 'input_ids': [], + 'input_embeddings': list[torch.Tensor], + 'input_embedding_ranges': list[torch.Tensor], + ... + """ + result = self.model.to_turbomind(messages, chat_template, tokenizer, + sequence_start) + # clear data + for i, message in enumerate(messages): + if isinstance(message['content'], List): + messages[i]['preprocess'] = None + messages[i]['forward'] = None + return result diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index 9c5f5f6e6a..0ee22b4688 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -2,8 +2,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union -import PIL -import torch +import numpy as np from mmengine import Registry from transformers import AutoConfig @@ -20,35 +19,227 @@ def __init__(self, model_path: str, with_llm: bool = False, max_memory: Dict[int, int] = None, - hf_config: AutoConfig = None): + hf_config: AutoConfig = None, + backend: str = ''): """init.""" self.model_path = model_path self.with_llm = with_llm self.max_memory = max_memory + self.backend = backend if hf_config is None: _, hf_config = get_model_arch(model_path) self.hf_config = hf_config - self.build_model() @abstractmethod - def build_model(): - """build model.""" + def build_preprocessor(self, ): + """build the preprocessor. + + NOTE: When the derived class implements this method, try not to + introduce the upper stream model repo as a thirdparty package + """ raise NotImplementedError() + def build_model(self, ): + """build the vision part of a VLM model when backend is turbomind. + + But when `with_llm=True`, load the whole VLM model + """ + if self.backend == 'turbomind' or self.with_llm: + raise NotImplementedError() + @abstractmethod + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """preprocess multimodal data in the messages. The derived class, + i.e., a specific vision model, takes the charge of image preprocessing + and the result management. + It can integrate the result into the messages list, or insert it to + the individual image item. + Args: + message(Dict): multimodal data in a dict, which is as follows: + [ + {'role': 'user', 'content': 'user prompt'}, + {'role': 'assisant', 'content': 'AI reponse'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'string', + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + ... + ] + } + {....} + ] + Returns: + the message list with preprocessing results included, which is + determined by the derived classes + """ # noqa + raise NotImplementedError() + def forward(self, - images: List[PIL.Image.Image], - image_kwargs: List[Dict] = None) -> List[torch.Tensor]: - """extract image feature. + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. 
ONLY implement it when the backend is + turbomind engine. Args: - images (List[PIL.Image.Image]): input images - image_kwargs (List[Dict]): input kwargs for each images - + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: - List[torch.Tensor]: extract image feature for each input image + the message list with forwarding results included, which is + determined by the derived classes """ - raise NotImplementedError() + if self.backend == 'turbomind': + raise NotImplementedError() + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + """pack the preprocessing results in a format compatible with what is + required by pytorch engine. ONLY implement it when the backend is + pytorch engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'pytorch': + raise NotImplementedError() + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + """pack the forwarding results in a format compatible with what is + required by turbomind engine. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'turbomind': + raise NotImplementedError() + + @staticmethod + def collect_images(messages): + """gather all images along with their respective parameters from the + messages and compile them into a single list. Each image is converted + to RGB color space. + + Args: + messages (List[Tuple[Image, Dict]]): a list of images with their + corresponding parameters + """ # noqa + images = [] + for message in messages: + content = message['content'] + if not isinstance(content, List): + continue + images.extend([ + (x['image'], + {k: v + for k, v in x.items() if k not in {'type', 'image'}}) + for x in content if x['type'] == 'image' + ]) + return images + + @staticmethod + def to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start): + """auxiliary function to pack the preprocessing results in a format + compatible with what is required by pytorch engine. 
+ + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect all preprocessing result from messages + preps = [x['content'] for x in messages if x['role'] == 'preprocess'] + assert len(preps) == 1 + preps = preps[0] + + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(preps) + 1, ( + f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(preps)}') + + # calculate the image token offset for each image + input_ids = [] + for i, seg in enumerate(segs): + if i > 0 and i <= len(preps): + preps[i - 1].update(offset=len(input_ids)) + image_tokens = preps[i - 1]['image_tokens'] + image_token_id = preps[i - 1]['image_token_id'] + input_ids.extend([image_token_id] * image_tokens) + token_ids = tokenizer.encode(seg, + add_bos=((i == 0) and sequence_start)) + input_ids.extend(token_ids) + + return dict(prompt=prompt, input_ids=input_ids, multimodal=preps) + + @staticmethod + def to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start): + """auxiliary function to pack the forwarding results in a format + compatible with what is required by turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect image features from messages + features = [x['content'] for x in messages if x['role'] == 'forward'] + features = features[0] + features = [x.cpu().numpy() for x in features] + + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(features) + 1, ( + f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(features)}') + + # tokenizer prompt, and get input_embeddings and input_embedding_ranges + input_ids = [] + begins = [] + ends = [] + IMAGE_DUMMY_TOKEN_INDEX = 0 + for i, seg in enumerate(segs): + if i > 0 and i <= len(features): + image_dim = features[i - 1].shape[0] + begins.append(len(input_ids)) + ends.append(begins[-1] + image_dim) + input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) + seg_ids = tokenizer.encode(seg, + add_bos=((i == 0) and sequence_start)) + input_ids.extend(seg_ids) + ranges = np.stack([begins, ends], axis=1).tolist() + return dict(prompt=prompt, + input_ids=input_ids, + input_embeddings=features, + input_embedding_ranges=ranges) @classmethod def match(cls, config: AutoConfig): diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py index 2401b42259..00e668c034 100644 --- a/lmdeploy/vl/model/builder.py +++ b/lmdeploy/vl/model/builder.py @@ -2,6 +2,8 @@ import os from typing import Optional, Union +import torch + from lmdeploy.archs import get_model_arch from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig from lmdeploy.utils import get_logger, get_model @@ -29,6 +31,7 @@ def load_vl_model(model_path: str, + backend: str, with_llm: bool = False, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None): @@ -36,8 +39,9 @@ def load_vl_model(model_path: str, Args: model_path(str): the path or repo_id from model hub of the model - with_llm(bool): whether to remove 
the LLM part from the model. - When it is False, it means removing LLM part + backend(str): the name of inference backend + with_llm(bool): load LLM model or not. Set it to False for VLM + inference scenarios and True for VLM quantization backend_config: the config of the inference engine """ if not os.path.exists(model_path): @@ -49,7 +53,6 @@ def load_vl_model(model_path: str, max_memory = None if not with_llm: - import torch tp = getattr(backend_config, 'tp', 1) max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)} @@ -57,30 +60,21 @@ def load_vl_model(model_path: str, kwargs = dict(model_path=model_path, with_llm=with_llm, max_memory=max_memory, - hf_config=hf_config) + hf_config=hf_config, + backend=backend) for name, module in VISION_MODELS.module_dict.items(): try: if module.match(hf_config): logger.info(f'matching vision model: {name}') - return module(**kwargs) - except Exception: - logger.error(f'matching vision model: {name} failed') + model = module(**kwargs) + model.build_preprocessor() + # build the vision part of a VLM model when backend is + # turbomind, or load the whole VLM model when `with_llm==True` + if backend == 'turbomind' or with_llm: + model.build_model() + return model + except Exception as e: + logger.error(f'build vision model {name} failed, {e}') raise raise ValueError(f'unsupported vl model with config {hf_config}') - - -def vl_model_with_tokenizer(model_path: str, with_llm: bool = True): - """load visual model.""" - vl_model = load_vl_model(model_path, with_llm).vl_model - llm = vl_model - if hasattr(vl_model, 'language_model'): # deepseek vl - llm = vl_model.language_model - if hasattr(vl_model, 'llm'): # MiniCPMV - llm = vl_model.llm - llm.config.use_cache = False - llm.half().eval() - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path, - trust_remote_code=True) - return vl_model, llm, tokenizer diff --git a/lmdeploy/vl/model/cogvlm.py b/lmdeploy/vl/model/cogvlm.py index ea5a06159e..07d97153f9 100644 --- a/lmdeploy/vl/model/cogvlm.py +++ b/lmdeploy/vl/model/cogvlm.py @@ -1,13 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. 
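The new to_pytorch_aux/to_turbomind_aux helpers in vl/model/base.py above share one splicing idea: split the chat-templated prompt on the image placeholder, expand each placeholder into image_tokens copies of image_token_id while recording where it starts, and tokenize the text segments in between. A stand-alone sketch of the pytorch-side variant, using a stand-in tokenizer rather than lmdeploy's real one:

IMAGE_TOKEN = '<IMAGE_TOKEN>'  # hypothetical placeholder string


def fake_encode(text, add_bos=False):
    # stand-in for tokenizer.encode(): one dummy id per whitespace token
    ids = [abs(hash(w)) % 30000 for w in text.split()]
    return ([1] if add_bos else []) + ids


# what preprocess() would have attached for two images
preps = [dict(image_tokens=3, image_token_id=0),
         dict(image_tokens=5, image_token_id=0)]
prompt = f'compare {IMAGE_TOKEN} with {IMAGE_TOKEN} please'

segs = prompt.split(IMAGE_TOKEN)
assert len(segs) == len(preps) + 1, 'placeholder/image count mismatch'

input_ids = []
for i, seg in enumerate(segs):
    if i > 0:
        preps[i - 1]['offset'] = len(input_ids)
        input_ids.extend([preps[i - 1]['image_token_id']] *
                         preps[i - 1]['image_tokens'])
    input_ids.extend(fake_encode(seg, add_bos=(i == 0)))

result = dict(prompt=prompt, input_ids=input_ids, multimodal=preps)

to_turbomind_aux walks the segments the same way but records (begin, end) ranges and passes the vision features as input_embeddings instead of per-image offsets.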
-import warnings -from typing import List - -import torch -from PIL.Image import Image -from transformers import AutoModelForCausalLM +from typing import Dict, List +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging + +logger = get_logger('lmdeploy') @VISION_MODELS.register_module() @@ -16,7 +13,7 @@ class CogVLMVisionModel(VisonModel): _arch = 'CogVLMForCausalLM' - def build_model(self): + def build_preprocessor(self): from torchvision import transforms self.image_transform = transforms.Compose([ transforms.Resize( @@ -26,57 +23,73 @@ def build_model(self): transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) + image_size = self.hf_config.vision_config['image_size'] + patch_size = self.hf_config.vision_config['patch_size'] + self.n_token_per_image = 2 + (image_size // patch_size // 2)**2 - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - with init_empty_weights(), warnings.catch_warnings(): - model = AutoModelForCausalLM.from_config(self.hf_config, - trust_remote_code=True) - if not self.with_llm: - del model.lm_head - for key in ['layers', 'norm', 'embed_tokens']: - setattr(model.model, key, None) - else: - self.vl_model = model + def build_model(self): + if self.with_llm: + from transformers import AutoModelForCausalLM + self.vl_model = AutoModelForCausalLM.from_pretrained( + self.model_path, device_map='cpu', trust_remote_code=True) + else: + raise NotImplementedError('turbomind has not supported cogvlm yet') - no_split_module_classes = ['TransformerLayer'] - max_memory = get_balanced_memory( - model, - max_memory=self.max_memory, - dtype=torch.half, - no_split_module_classes=no_split_module_classes) - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_module_classes, - max_memory=max_memory, - dtype=torch.half) - same_device_keys = [('model.vision.linear_proj', 'model.vision.boi', - 'model.vision.eoi')] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to the spec of `super().preprocess`""" + images = self.collect_images(messages) + outputs = [] + for image, _ in images: + image = image.convert('RGB') + pixel_values = self.image_transform(image) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + + prompt_messages.append( + dict(role='user', content=content[0], num_images=n_images)) - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - 
no_split_module_classes=no_split_module_classes, - dtype=torch.half) - self.model = model.model.vision - self.model.eval() + from lmdeploy.model import Vicuna + llm_chat_template = Vicuna(eoa=chat_template.eoa, + stop_words=chat_template.stop_words) + prompt = '' + IMAGE_TOKEN = '' + for i, msg in enumerate(prompt_messages): + num_images = msg.pop('num_images', 0) + if num_images == 0: + role = msg['role'] + msg = llm_chat_template.messages2prompt([msg], sequence_start + and i == 0) + msg = dict(role=role, content=msg) + prompt_i = chat_template.messages2prompt([msg], sequence_start + and i == 0) + if num_images > 0: + prompt_i = (IMAGE_TOKEN * num_images) + prompt_i + prompt += prompt_i + return prompt, IMAGE_TOKEN - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.image_transform(x) for x in outputs] - outputs = torch.stack(outputs, dim=0).to(device='cuda:0', - dtype=torch.half) - outputs = self.model(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/deepseek.py b/lmdeploy/vl/model/deepseek.py index bfbf03f01e..9780744cf2 100644 --- a/lmdeploy/vl/model/deepseek.py +++ b/lmdeploy/vl/model/deepseek.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. - import warnings -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoModelForCausalLM +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + def check_deepseek_vl_install(): """check deepseek_vl install.""" @@ -18,8 +19,8 @@ def check_deepseek_vl_install(): except ImportError: raise ImportError( 'To use DeepSeekVLModel, please install deepseek_vl by ' - 'pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git' - ' --no-deps') + '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git' + ' --no-deps`') @VISION_MODELS.register_module() @@ -28,18 +29,22 @@ class DeepSeekVisionModel(VisonModel): _arch = 'MultiModalityCausalLM' - def build_model(self): + def build_preprocessor(self): check_deepseek_vl_install() - # empty init - from accelerate import init_empty_weights from deepseek_vl.models import VLChatProcessor + self.image_processor = VLChatProcessor.from_pretrained( + self.model_path).image_processor + + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import init_empty_weights with init_empty_weights(): warnings.simplefilter('ignore') model = AutoModelForCausalLM.from_pretrained(self.model_path) + self.vl_model = model if not self.with_llm: del model.language_model - else: - self.vl_model = model from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory(model, @@ -79,23 +84,111 @@ def build_model(self): device_map=device_map if not self.with_llm else {'': 'cpu'}, dtype=torch.half) + self.model = model.eval() self.vision_model = model.vision_model.eval() self.aligner = model.aligner.eval() - 
self.image_processor = VLChatProcessor.from_pretrained( - self.model_path).image_processor + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to the spec of `super.preprocess()""" + images = self.collect_images(messages) + outputs = [] + for image, _ in images: + image = image.convert('RGB') + pixel_values = self.image_processor( + [image], return_tensors='pt').pixel_values + outputs.append( + dict( + pixel_values=pixel_values, + image_size=image.size, + # refer to https://github.com/deepseek-ai/DeepSeek-VL/blob/main/deepseek_vl/models/processing_vlm.py # noqa + # which is hardcoded 576 + image_tokens=576, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - pixel_values = self.image_processor(outputs, - return_tensors='pt').pixel_values - pixel_values = pixel_values.to(device=next( - self.vision_model.parameters()).device, - dtype=torch.float16) - # [b x n_images, T2, D] - images_embeds = self.aligner(self.vision_model(pixel_values)) - - outputs = torch.split(images_embeds, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=next( + self.vision_model.parameters()).device, + dtype=torch.float16) + # [b x n_images, T2, D] + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.aligner(self.vision_model(pixel_values)) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + # apply chat template to get the prompt + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + content = content[0] + n_image = sum( + [1 for x in message['content'] if x['type'] == 'image']) + n_placeholder = content.count(IMAGE_TOKEN) + if n_placeholder == 0: + logger.warning( + f"""for deepseek-vl model, the user should insert the {IMAGE_TOKEN} + to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html + for more details.""") # noqa + if n_placeholder != 0 and n_placeholder != n_image: + logger.error( + f'unmatched placeholder and image: {n_placeholder} vs ' + f'{n_image}. 
Ignore the placeholder') + content = content.replace(IMAGE_TOKEN, '') + n_placeholder = 0 + if n_placeholder == 0: + if n_image == 1: + content = f'{IMAGE_TOKEN}{content}' + else: + content = ''.join([ + f'{IMAGE_TOKEN} is Figure {str(i)}.\n' + for i in range(n_image) + ]) + content + prompt_messages.append(dict(role='user', content=content)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/glm_4v.py b/lmdeploy/vl/model/glm_4v.py index 34e060f4c9..813813bf09 100644 --- a/lmdeploy/vl/model/glm_4v.py +++ b/lmdeploy/vl/model/glm_4v.py @@ -1,77 +1,30 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List -import warnings -from typing import List - -import torch -from PIL.Image import Image from transformers import AutoConfig +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging + +logger = get_logger('lmdeploy') @VISION_MODELS.register_module() class GLM4VisionModel(VisonModel): """glm-4v-9b vision model.""" - _arch = 'ChatGLMModel' + _arch = ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] @classmethod def match(cls, config: AutoConfig): """check whether the config match the model.""" arch = config.architectures[0] - if arch == cls._arch and hasattr(config, 'vision_config'): + if arch in cls._arch and hasattr(config, 'vision_config'): return True return False - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import infer_auto_device_map + def build_preprocessor(self): from torchvision import transforms - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_config(self.hf_config, - trust_remote_code=True) - if not self.with_llm: - del model.transformer.embedding - del model.transformer.rotary_pos_emb - del model.transformer.encoder - del model.transformer.output_layer - else: - self.vl_model = model - - no_split_module_classes = ['TransformerLayer'] - - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_module_classes, - max_memory=self.max_memory, - dtype=torch.half) - - same_device_keys = [ - ('transformer.vision.linear_proj', 'transformer.vision.boi', - 'transformer.vision.eoi') - ] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: - continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] - - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - - model.eval() - self.model = model self.image_transform = transforms.Compose([ transforms.Resize( (self.hf_config.vision_config['image_size'], ) * 2, @@ -80,15 +33,65 @@ def 
build_model(self): transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) + image_size = self.hf_config.vision_config['image_size'] + patch_size = self.hf_config.vision_config['patch_size'] + self.n_token_per_image = 2 + (image_size // patch_size // 2)**2 + + def build_model(self): + if self.with_llm: + from transformers import AutoModelForCausalLM + self.vl_model = AutoModelForCausalLM.from_pretrained( + self.model_path, device_map='cpu', trust_remote_code=True) + else: + raise NotImplementedError('turbomind has not supported glm4v yet') + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to the spec of `super.preprocess()""" + outputs = [] + for message in messages: + if not isinstance(message['content'], List): + continue + images = [ + x['image'] for x in message['content'] if x['type'] == 'image' + ] + if len(images) > 1: + logger.warning( + f'glm4v does not support the input of multiple images' + f' in a single chat round, but got {len(images)} images.') + # we still pass all the images to the model and let the + # model decide what to do + images = [x.convert('RGB') for x in images] + pixel_values = [self.image_transform(x) for x in images] + outputs.extend([ + dict(pixel_values=_2, + image_size=_1.size, + image_tokens=self.n_token_per_image, + image_token_id=0) for _1, _2 in zip(images, pixel_values) + ]) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + content = message['content'] + if isinstance(content, str): + prompt_messages.append(message) + continue + elif message['role'] in ['preprocess', 'forward']: + continue + prompt = [x['text'] for x in content if x['type'] == 'text'] + n_images = len([1 for x in content if x['type'] == 'image']) + prompt = ''.join([f'{IMAGE_TOKEN}\n'] * n_images) + prompt[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.image_transform(x) for x in outputs] - outputs = torch.stack(outputs, dim=0).to(device='cuda:0', - dtype=torch.half) - outputs = self.model.transformer.vision(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py index fa67192f11..979b8d1a39 100644 --- a/lmdeploy/vl/model/internvl.py +++ b/lmdeploy/vl/model/internvl.py @@ -1,10 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
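Both the CogVLM and GLM-4V preprocessors above size their image placeholder as n_token_per_image = 2 + (image_size // patch_size // 2)**2, i.e. two boundary tokens plus one token per 2x2 block of vision patches. A quick check of that arithmetic with assumed vision_config values (not read from any particular checkpoint):

# illustrative values only; the real ones come from hf_config.vision_config
image_size, patch_size = 1120, 14

patches_per_side = image_size // patch_size           # 80 patches per side
n_token_per_image = 2 + (patches_per_side // 2) ** 2  # 2 + 40**2 = 1602
print(n_token_per_image)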
- from typing import Dict, List import torch -from PIL.Image import Image -from transformers import AutoModel, CLIPImageProcessor +from transformers import AutoConfig, AutoModel, CLIPImageProcessor from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel @@ -80,34 +78,16 @@ class InternVLVisionModel(VisonModel): _arch = 'InternVLChatModel' - def build_model(self): - """Load model.""" - from accelerate import init_empty_weights - with init_empty_weights(): - config = self.hf_config - # transformers below 4.37.0 may raise error about flash_attn - config.llm_config.attn_implementation = 'eager' - model = AutoModel.from_config(config, trust_remote_code=True) - if not self.with_llm: - del model.language_model - else: - self.vl_model = model - model.half() + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) - from accelerate import load_checkpoint_and_dispatch - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map='auto' if not self.with_llm else {'': 'cpu'}, - max_memory=self.max_memory, - no_split_module_classes=['InternVisionEncoderLayer'], - dtype=torch.half) - - # We need eval mode to freeze the weights in model, thus, - # avoid randomness in inference. - self.model = model.eval() - self.config = config + def build_preprocessor(self): + self.config = self.hf_config dynamic_image_size = getattr(self.config, 'dynamic_image_size', False) image_processor = None try: @@ -131,62 +111,180 @@ def build_model(self): T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) + self.processor = self._preprocess_v1_5 self._forward_func = self._forward_v1_5 else: + self.processor = self._preprocess self.image_processor = image_processor self._forward_func = self._forward - def _preprocess_v1_5(self, images: List[Image], params: List[Dict] = None): - if params is not None: - assert len(images) == len( - params), 'different length of images and params' - else: - params = [{}] * len(images) + force_image_size = self.hf_config.force_image_size + patch_size = self.hf_config.vision_config.patch_size + downsample_ratio = self.hf_config.downsample_ratio + self.image_tokens_per_patch = int( + (force_image_size // patch_size)**2 * (downsample_ratio**2)) - image_res = {'low': 6, 'medium': 12, 'high': 24} + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import init_empty_weights + with init_empty_weights(): + # transformers below 4.37.0 may raise error about flash_attn + self.config.llm_config.attn_implementation = 'eager' + model = AutoModel.from_config(self.config, trust_remote_code=True) + self.vl_model = model + if not self.with_llm: + del model.language_model - outputs = [] - for image, param in zip(images, params): - max_num = param.get('max_dynamic_patch') - if max_num is None or not isinstance(max_num, int): - res_key = param.get('detail', 'default') - max_num = image_res.get(res_key, self.config.max_dynamic_patch) - out = dynamic_preprocess( - image, - min_num=self.config.min_dynamic_patch, - max_num=max_num, - image_size=self.config.vision_config.image_size, - use_thumbnail=self.config.use_thumbnail) - out = [self.transform(x) for x in out] - out = torch.stack(out) # (patch) x c x h x w - outputs.append(out) - return 
outputs + model.half() + from accelerate import load_checkpoint_and_dispatch + with disable_logging(): + load_checkpoint_and_dispatch( + model=model, + checkpoint=self.model_path, + device_map='auto' if not self.with_llm else {'': 'cpu'}, + max_memory=self.max_memory, + no_split_module_classes=['InternVisionEncoderLayer'], + dtype=torch.half) + + # We need eval mode to freeze the weights in model, thus, + # avoid randomness in inference. + self.model = model.eval() + + def _preprocess_v1_5(self, image, params=None): + image_res = {'low': 6, 'medium': 12, 'high': 24} + max_num = params.get('max_dynamic_patch') + if max_num is None or not isinstance(max_num, int): + res_key = params.get('detail', 'default') + max_num = image_res.get(res_key, self.config.max_dynamic_patch) + out = dynamic_preprocess( + image, + min_num=self.config.min_dynamic_patch, + max_num=max_num, + image_size=self.config.vision_config.image_size, + use_thumbnail=self.config.use_thumbnail) + pixel_values = [self.transform(x) for x in out] + # (patch) x c x h x w + pixel_values = torch.stack(pixel_values) + return pixel_values - def _forward_v1_5(self, images: List[Image], params: List[Dict] = None): + def _forward_v1_5(self, inputs, max_batch_size): """forward for internvl-chat-v1-5.""" - outputs = self._preprocess_v1_5(images, params) - split = [x.shape[0] for x in outputs] - outputs = torch.cat(outputs, dim=0) - outputs = outputs.to(self.model.device, dtype=torch.float16) - outputs = self.model.extract_feature(outputs) - outputs = torch.split(outputs, split, dim=0) - outputs = [x.reshape(-1, x.shape[-1]) for x in outputs] + assert all(x.get('pixel_values') is not None for x in inputs) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + split = [x.shape[0] for x in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(self.model.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model.extract_feature(pixel_values) + feats = torch.split(feats, split, dim=0) + outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats]) return outputs - def _forward(self, images: List[Image], params: List[Dict] = None): + def _preprocess(self, image, params=None): """forward for internvl-chat-v1-1, internvl-chat-v1-2.""" - pixel_values = self.image_processor(images=images, + pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values - pixel_values = pixel_values.to(self.model.device, dtype=torch.float16) - outputs = self.model.extract_feature(pixel_values) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] + return pixel_values + + def _forward(self, inputs, max_batch_size): + """forward for internvl-chat-v1-1, internvl-chat-v1-2.""" + assert all(x.get('pixel_values') is not None for x in inputs) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(self.model.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model.extract_feature(pixel_values) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) return outputs + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + 
images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = self.processor(image, params) + image_tokens = (pixel_values.shape[0] * + self.image_tokens_per_patch) + outputs.append( + dict(pixel_values=pixel_values, + image_tokens=image_tokens, + image_token_id=0, + image_size=image.size)) + messages.append(dict(role='preprocess', content=outputs)) + return messages + @torch.no_grad() def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - images = [x.convert('RGB') for x in images] - return self._forward_func(images, params) + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = self._forward_func(inputs, max_batch_size) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt and f'{IMAGE_TOKEN}' not in prompt: + prompt = prompt.replace(f'{IMAGE_TOKEN}', + f'{IMAGE_TOKEN}') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + elif IMAGE_TOKEN not in prompt: + prompt = f'{IMAGE_TOKEN * n_images}\n' + prompt + else: + pass + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/internvl_llava.py b/lmdeploy/vl/model/internvl_llava.py index f607082b18..17a12f71ca 100644 --- a/lmdeploy/vl/model/internvl_llava.py +++ b/lmdeploy/vl/model/internvl_llava.py @@ -2,14 +2,13 @@ import warnings from contextlib import contextmanager -from typing import List, Union +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoConfig, AutoModelForCausalLM from lmdeploy.utils import get_logger -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel from lmdeploy.vl.model.utils import rewrite_ctx from .utils import disable_logging, disable_transformers_logging @@ -18,14 +17,13 @@ def check_llava_install(): - """check llava 
install.""" try: from llava.model.multimodal_encoder.clip_encoder import \ InternVisionModel # noqa: F401 except ImportError: raise ImportError( 'To use LlavaVLModel, please install llava by ' - 'pip install "git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava" --no-deps' # noqa: E501 + '`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`' # noqa: E501 ) @@ -65,7 +63,7 @@ def init_empty_vit(): @VISION_MODELS.register_module() -class InternVLLlavaVisionModel(VisonModel): +class InternVLLlavaVisionModel(LlavaVisionModel): """Llava visual model.""" @classmethod @@ -78,9 +76,12 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + return super().build_preprocessor() + def build_model(self): - """build model & load weights.""" - # check llava install + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_llava_install() # currently, only support llava llama from llava.model.language_model.llava_llama import ( # noqa @@ -98,13 +99,12 @@ def build_model(self): } # disable vision part quantization model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: del model.lm_head del model.model.embed_tokens del model.model.layers del model.model.norm - else: - self.vl_model = model with init_empty_vit(): vision_tower = model.get_vision_tower() @@ -137,42 +137,43 @@ def build_model(self): self.vision_tower = model.model.vision_tower.eval() self.mm_projector = model.model.mm_projector.eval() - def encode_images(self, images: torch.Tensor) -> torch.Tensor: - """encode images.""" - image_features = self.vision_tower(images) - image_features = self.mm_projector(image_features) - return image_features - - def preprocess( - self, - images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]: - """preprocess.""" - # TODO: gpu processor - from llava.mm_utils import process_images - images = [x.convert('RGB') for x in images] - image_processor = self.vision_tower.image_processor - outputs = process_images(images, image_processor, self.config) - return outputs + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + return super().preprocess(messages) @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - images = self.preprocess(images) - if isinstance(images, list): - images = [ - x.to(self.vision_tower.device, dtype=torch.float16) - for x in images + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. 
+ + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] ] - else: - images = images.to(self.vision_tower.device, dtype=torch.float16) - - if type(images) is list or images.ndim == 5: - concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) - split_sizes = [image.shape[0] for image in images] - image_features = torch.split(image_features, split_sizes, dim=0) - image_features = [x.flatten(0, 1) for x in image_features] - else: - image_features = self.encode_images(images) - image_features = [x for x in image_features] - return image_features + split_sizes = [x.shape[0] for x in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + if pixel_values.ndim == 5: + feats = self.encode_images(pixel_values) + feats = torch.split(feats, split_sizes, dim=0) + feats = [x.flatten(0, 1) for x in feats] + else: + feats = self.encode_images(pixel_values) + feats = [x for x in feats] + outputs.extend(feats) + messages.append(dict(role='forward', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/llava.py b/lmdeploy/vl/model/llava.py index 0b18f460cd..7ad919bef7 100644 --- a/lmdeploy/vl/model/llava.py +++ b/lmdeploy/vl/model/llava.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -# Modified from -# https://github.com/haotian-liu/LLaVA.git +# Modified from https://github.com/haotian-liu/LLaVA.git +import ast +import math import warnings from contextlib import contextmanager -from typing import List, Union +from typing import Dict, List import torch -from PIL.Image import Image +from PIL import Image from transformers import AutoConfig, AutoModelForCausalLM from lmdeploy.utils import get_logger -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel from lmdeploy.vl.model.utils import disable_logging, rewrite_ctx logger = get_logger('lmdeploy') @@ -23,16 +24,14 @@ def check_llava_install(): except ImportError: raise ImportError( 'To use LlavaVLModel, please install llava by ' - 'pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps' # noqa: E501 + '`pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps`' # noqa: E501 ) def _clip_vision_tower_load_model(self, **kwargs): logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}') - from transformers import (CLIPImageProcessor, CLIPVisionConfig, - CLIPVisionModel) - self.image_processor = CLIPImageProcessor.from_pretrained( - self.vision_tower_name) + from transformers import CLIPVisionConfig, CLIPVisionModel + config = CLIPVisionConfig.from_pretrained(self.vision_tower_name) self.vision_tower = CLIPVisionModel._from_config(config=config) self.vision_tower.requires_grad_(False) @@ -53,8 +52,166 @@ def init_llava_vision_tower(config): yield +def select_best_resolution(original_size, possible_resolutions): + """Selects the best resolution from a list of possible resolutions based on + the original size. 
+ + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ # noqa + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int( + original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, + original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """Resize and pad an image to a target resolution while maintaining aspect + ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ # noqa + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def process_anyres_image(image, processor, grid_pinpoints): + """Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. 
+ """ # noqa + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size['height']) + + image_original_resize = image.resize( + (processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [ + processor.preprocess(image_patch, + return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches + ] + return torch.stack(image_patches, dim=0) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None) + new_images = [] + if image_aspect_ratio == 'pad': + for image in images: + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean)) + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + elif image_aspect_ratio == 'anyres': + for image in images: + image = process_anyres_image(image, image_processor, + model_cfg.image_grid_pinpoints) + new_images.append(image) + else: + return image_processor(images, return_tensors='pt')['pixel_values'] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + @VISION_MODELS.register_module() -class LlavaVisionModel(VisonModel): +class LlavaVisionModel(LlavaHfVisionModel): """Llava visual model.""" @classmethod @@ -73,9 +230,20 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + from transformers import CLIPImageProcessor + self.image_processor = CLIPImageProcessor.from_pretrained( + self.hf_config.mm_vision_tower) + config = AutoConfig.from_pretrained(self.hf_config.mm_vision_tower) + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.n_token_per_image = (image_size // patch_size)**2 + if self.hf_config.mm_vision_select_feature == 'cls_patch': + self.n_token_per_image += 1 + def build_model(self): - """build model & load weights.""" - # check llava install + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_llava_install() self.arch = self.hf_config.architectures[0] @@ -104,15 +272,13 @@ def build_model(self): model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: # remove the LLM part from llava model. 
- # Instead, Load the LLM part to turbomind engine del model.lm_head del model.model.embed_tokens del model.model.layers del model.model.norm - else: - self.vl_model = model # init empty vision_tower, the embedding layer in CLIPVisionModel # can't init right under init_empty_weights @@ -143,101 +309,113 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor: image_features = self.mm_projector(image_features) return image_features - def preprocess( - self, - images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]: - """preprocess.""" - # TODO: gpu processor - from llava.mm_utils import process_images - images = [x.convert('RGB') for x in images] - image_processor = self.vision_tower.image_processor - outputs = process_images(images, image_processor, self.config) - return outputs + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = process_images([image], self.image_processor, + self.config) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ from llava.model.llava_arch import (get_anyres_image_grid_shape, unpad_image) - image_sizes = [x.size for x in images] - images = self.preprocess(images) - if isinstance(images, list): - images = [ - x.to(device=self.vision_tower.device, dtype=torch.float16) - for x in images + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + image_sizes = [ + x['image_size'] for x in inputs[idx:idx + max_batch_size] ] - else: - images = images.to(device=self.vision_tower.device, - dtype=torch.float16) - if type(images) is list or images.ndim == 5: - if type(images) is list: - images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] - concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) - split_sizes = [image.shape[0] for image in images] - image_features = torch.split(image_features, split_sizes, dim=0) - mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', - 'flat') - image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', - 'square') - if mm_patch_merge_type == 'flat': - image_features = [x.flatten(0, 1) for x in image_features] - elif mm_patch_merge_type.startswith('spatial'): - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - height = width = self.vision_tower.num_patches_per_side - assert height * width == base_image_feature.shape[0] - if image_aspect_ratio == 'anyres': - num_patch_width, num_patch_height = \ - get_anyres_image_grid_shape( - image_sizes[image_idx], - 
self.config.image_grid_pinpoints, - self.vision_tower.config.image_size) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, - width, -1) - else: - raise NotImplementedError - if 'unpad' in mm_patch_merge_type: - image_feature = image_feature.permute( - 4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, - 2).flatten( - 2, 3) - image_feature = unpad_image( - image_feature, image_sizes[image_idx]) - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1).to( - image_feature.device)), - dim=-1) - image_feature = image_feature.flatten(1, - 2).transpose( - 0, 1) + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + if pixel_values[0].ndim == 5: + split_sizes = [x.shape[1] for x in pixel_values] + pixel_values = torch.cat([x for x in pixel_values], dim=1) + logger.info(f'vision forward shape: {pixel_values.shape}') + pixel_values = pixel_values.squeeze(0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + feats = self.encode_images(pixel_values) + feats = torch.split(feats, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, + 'mm_patch_merge_type', 'flat') + image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', + 'square') + if mm_patch_merge_type == 'flat': + outputs.expand([x.flatten(0, 1) for x in feats]) + elif mm_patch_merge_type.startswith('spatial'): + for img_idx, feat in enumerate(feats): + if feat.shape[0] > 1: + base_feat = feat[0] + feat = feat[1:] + height = self.vision_tower.num_patches_per_side + width = self.vision_tower.num_patches_per_side + assert height * width == base_feat.shape[0] + if image_aspect_ratio == 'anyres': + num_patch_width, num_patch_height = \ + get_anyres_image_grid_shape( + image_sizes[img_idx], + self.config.image_grid_pinpoints, + self.vision_tower.config.image_size) + feat = feat.view(num_patch_height, + num_patch_width, height, + width, -1) + else: + raise NotImplementedError + if 'unpad' in mm_patch_merge_type: + feat = feat.permute(4, 0, 2, 1, 3).contiguous() + feat = feat.flatten(1, 2).flatten(2, 3) + feat = unpad_image(feat, image_sizes[img_idx]) + feat = torch.cat( + (feat, self.model. 
+ image_newline[:, None, None].expand( + *feat.shape[:-1], 1).to(feat.device)), + dim=-1) + feat = feat.flatten(1, 2).transpose(0, 1) + else: + feat = feat.permute(0, 2, 1, 3, 4).contiguous() + feat = feat.flatten(0, 3) + feat = torch.cat((base_feat, feat), dim=0) else: - image_feature = image_feature.permute( - 0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.flatten(0, 3) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - if 'unpad' in mm_patch_merge_type: - image_feature = torch.cat( - (image_feature, - self.model.image_newline[None].to( - image_feature.device)), - dim=0) - new_image_features.append(image_feature) - image_features = new_image_features + feat = feat[0] + if 'unpad' in mm_patch_merge_type: + feat = torch.cat( + (feat, self.model.image_newline[None].to( + feat.device)), + dim=0) + outputs.append(feat) + else: + raise ValueError('Unexpected mm_patch_merge_type: ' + f'{self.config.mm_patch_merge_type}') else: - raise ValueError('Unexpected mm_patch_merge_type: ' - f'{self.config.mm_patch_merge_type}') - else: - image_features = self.encode_images(images) - image_features = [x for x in image_features] - return image_features + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.encode_images(pixel_values) + outputs.extend([x for x in feats]) + messages.append(dict(role='forward', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py index 31be101ae8..c4e3c90bfb 100644 --- a/lmdeploy/vl/model/llava_hf.py +++ b/lmdeploy/vl/model/llava_hf.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. 
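# --- Illustrative sketch only, not part of the patch ---
# Exercising the anyres helpers added to lmdeploy/vl/model/llava.py above.
# The checkpoint name, image path and grid_pinpoints values are placeholder
# assumptions; any CLIP-style image processor exposing crop_size['height']
# and size['shortest_edge'] (as process_anyres_image expects) would do.
from PIL import Image
from transformers import CLIPImageProcessor

from lmdeploy.vl.model.llava import (divide_to_patches, process_anyres_image,
                                     resize_and_pad_image,
                                     select_best_resolution)

processor = CLIPImageProcessor.from_pretrained(
    'openai/clip-vit-large-patch14-336')  # assumed example checkpoint
image = Image.open('demo.jpg').convert('RGB')  # assumed local test image

grid_pinpoints = [(336, 672), (672, 336), (672, 672)]
# pick the candidate resolution that keeps the most effective pixels
best = select_best_resolution(image.size, grid_pinpoints)
padded = resize_and_pad_image(image, best)  # letterbox to that resolution
patches = divide_to_patches(padded, processor.crop_size['height'])
# process_anyres_image runs the steps above and prepends a resized copy of
# the original image, returning a (1 + num_patches, 3, H, W) tensor
pixel_values = process_anyres_image(image, processor, grid_pinpoints)
# --- end of sketch ---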
- import warnings -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoProcessor +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() class LlavaHfVisionModel(VisonModel): @@ -17,19 +18,31 @@ class LlavaHfVisionModel(VisonModel): _arch = 'LlavaForConditionalGeneration' + def build_preprocessor(self): + processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True) + if hasattr(processor, 'tokenizer'): + del processor.tokenizer + processor.prtokenizer = None + self.processor = processor.image_processor + image_size = self.hf_config.vision_config.image_size + patch_size = self.hf_config.vision_config.patch_size + self.n_token_per_image = (image_size // patch_size)**2 + if self.hf_config.vision_feature_select_strategy == 'full': + self.n_token_per_image += 1 + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights, load_checkpoint_and_dispatch with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') from transformers import LlavaForConditionalGeneration model = LlavaForConditionalGeneration._from_config(self.hf_config) + self.vl_model = model if not self.with_llm: del model.language_model - for key in ['language_model']: - setattr(model, key, None) - else: - self.vl_model = model # fix for llava-hf/llava-interleave-qwen-7b-hf setattr(model.config, 'tie_word_embeddings', False) @@ -45,35 +58,97 @@ def build_model(self): dtype=torch.half) model.eval() self.model = model - # processor - processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True) - if hasattr(processor, 'tokenizer'): - del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = self.processor( + image, return_tensors='pt', + input_data_format='channels_last').pixel_values + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - pixel_values = self.processor( - images, return_tensors='pt', - input_data_format='channels_last')['pixel_values'] - pixel_values = pixel_values.to(device=self.model.device, - dtype=self.model.dtype) - image_outputs = self.model.vision_tower.forward( - pixel_values, output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.hf_config.vision_feature_layer] - if self.hf_config.vision_feature_select_strategy == 'default': - image_features = image_features[:, 1:] - elif self.hf_config.vision_feature_select_strategy == 'full': - image_features = image_features - else: - raise ValueError( - 'Unexpected select feature strategy: ' - f'{self.hf_config.vision_feature_select_strategy}') - image_features = self.model.multi_modal_projector(image_features) - outputs = torch.split(image_features, 1, dim=0) - outputs = 
[x.squeeze() for x in outputs] - return outputs + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.model.device, + dtype=self.model.dtype) + logger.info(f'vision forward shape: {pixel_values.shape}') + image_outputs = self.model.vision_tower.forward( + pixel_values, output_hidden_states=True) + image_features = image_outputs.hidden_states[ + self.hf_config.vision_feature_layer] + if self.hf_config.vision_feature_select_strategy == 'default': + image_features = image_features[:, 1:] + elif self.hf_config.vision_feature_select_strategy == 'full': + image_features = image_features + else: + raise ValueError( + 'Unexpected select feature strategy: ' + f'{self.hf_config.vision_feature_select_strategy}') + image_features = self.model.multi_modal_projector(image_features) + image_features = torch.split(image_features, 1, dim=0) + outputs.extend([x.squeeze() for x in image_features]) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = (IMAGE_TOKEN + '\n') * n_images + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py index 9223ebea4f..d355a48d60 100644 --- a/lmdeploy/vl/model/llava_next.py +++ b/lmdeploy/vl/model/llava_next.py @@ -1,46 +1,51 @@ # Copyright (c) OpenMMLab. All rights reserved. 
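# --- Illustrative sketch only, not part of the patch ---
# How the reworked preprocess/forward interface shown above is expected to be
# driven. `vision_model` stands for an already-built LlavaHfVisionModel
# (build_preprocessor()/build_model() already called) and `img` for any
# PIL.Image; both are assumptions for illustration.
from typing import Dict, List

import torch


def extract_image_features(vision_model, img) -> List[torch.Tensor]:
    messages: List[Dict] = [
        dict(role='user',
             content=[
                 dict(type='text', text='describe the image'),
                 dict(type='image', image=img),
             ])
    ]
    # preprocess() appends dict(role='preprocess', content=[...]) carrying
    # pixel_values, image_size, image_tokens and image_token_id per image
    messages = vision_model.preprocess(messages)
    # forward() consumes that entry and appends dict(role='forward',
    # content=[...]) with one feature tensor per input image
    messages = vision_model.forward(messages, max_batch_size=1)
    return [x['content'] for x in messages if x['role'] == 'forward'][0]
# --- end of sketch ---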
- +import itertools import warnings -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image -from transformers import AutoProcessor -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.utils import get_logger +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() -class LlavaNextVisionModel(VisonModel): +class LlavaNextVisionModel(LlavaHfVisionModel): """Llava hf vision model.""" _arch = 'LlavaNextForConditionalGeneration' - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - + def build_preprocessor(self): + super().build_preprocessor() + # build the model with empty weights. The model will be used in + # `preprocess` to get the image token number + from accelerate import init_empty_weights with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') from transformers import LlavaNextForConditionalGeneration - model = LlavaNextForConditionalGeneration._from_config( + self.model = LlavaNextForConditionalGeneration._from_config( self.hf_config) + self.vl_model = self.model if not self.with_llm: - del model.language_model - for key in ['language_model']: - setattr(model, key, None) - else: - self.vl_model = model + del self.model.language_model + + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import load_checkpoint_and_dispatch + from accelerate.utils import get_balanced_memory, infer_auto_device_map no_split_module_classes = ['CLIPEncoderLayer'] max_memory = get_balanced_memory( - model, + self.model, max_memory=self.max_memory, dtype=torch.half, no_split_module_classes=no_split_module_classes) device_map = infer_auto_device_map( - model, + self.model, no_split_module_classes=no_split_module_classes, max_memory=max_memory, dtype=torch.half) @@ -55,75 +60,128 @@ def build_model(self): with disable_logging(): load_checkpoint_and_dispatch( - model=model, + model=self.model, checkpoint=self.model_path, device_map=device_map if not self.with_llm else {'': 'cpu'}, no_split_module_classes=no_split_module_classes, dtype=torch.half) - model.eval() - self.model = model - # processor - processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True) - if hasattr(processor, 'tokenizer'): - del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + self.model.eval() - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to the spec of `super.preprocess()""" from transformers.models.llava_next.modeling_llava_next import \ image_size_to_num_patches - """forward.""" - processed_inputs = self.processor(images, - return_tensors='pt', - input_data_format='channels_last') - pixel_values = processed_inputs['pixel_values'].to( - device=self.model.device, dtype=self.model.dtype) - image_sizes = processed_inputs['image_sizes'].to( - device=self.model.device, dtype=self.model.dtype) - # ! 
infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.hf_config.image_grid_pinpoints, - patch_size=self.hf_config.vision_config.image_size, - ) for imsize in image_sizes - ] - # figure out if pixel_values is concatenated or stacked - if pixel_values.dim() == 5: - # stacking when input is - # (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] - for pix_val, num_patch in zip(pixel_values, image_num_patches) + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + result = self.processor(image, + return_tensors='pt', + input_data_format='channels_last') + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.hf_config.image_grid_pinpoints, + patch_size=self.hf_config.vision_config.image_size, + ) for imsize in result['image_sizes'] ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of - # (num_patches, num_channels, height, width) - raise ValueError(f'pixel_values of shape {pixel_values.shape}, ' - 'expect to be of 4 or 5 dimensions') - image_outputs = self.model.vision_tower.forward( - pixel_values, output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.hf_config.vision_feature_layer] - if self.hf_config.vision_feature_select_strategy == 'default': - image_features = image_features[:, 1:] - elif self.hf_config.vision_feature_select_strategy == 'full': - image_features = image_features - else: - raise ValueError( - 'Unexpected select feature strategy: ' - f'{self.hf_config.vision_feature_select_strategy}') - image_features = self.model.multi_modal_projector(image_features) - image_features = torch.split(image_features, image_num_patches, dim=0) - image_features, feature_lens = self.model.pack_image_features( - image_features, - image_sizes, - image_newline=self.model.image_newline, - ) - outputs = torch.split(image_features, - feature_lens.cpu().numpy().tolist(), - dim=0) - return outputs + + hidden_size = self.hf_config.text_config.hidden_size + fake_image_features = torch.zeros( + [image_num_patches[0], self.n_token_per_image, hidden_size]) + image_sizes = result['image_sizes'] + image_newline = torch.randn(self.hf_config.text_config.hidden_size) + strategy = self.hf_config.vision_feature_select_strategy + _, image_tokens = self.model.pack_image_features( + [fake_image_features], + image_sizes, + vision_feature_select_strategy=strategy, + image_newline=image_newline) + result.update( + dict(image_size=image.size, + image_patches=image_num_patches, + image_tokens=image_tokens, + image_token_id=0)) + outputs.append(result) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + @torch.no_grad() + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. 
+ + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'].to(device=self.model.device, + dtype=self.model.dtype) + for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + image_sizes = [ + x['image_sizes'].to(device=self.model.device, + dtype=self.model.dtype) + for x in inputs[idx:idx + max_batch_size] + ] + image_sizes = torch.cat(image_sizes, dim=0) + image_num_patches = [ + x['num_patch'] for x in inputs[idx:idx + max_batch_size] + ] + image_num_patches = list(itertools.chain(*image_num_patches)) + # figure out if pixel_values is concatenated or stacked + if pixel_values.dim() == 5: + # stacking when input is + # (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] for pix_val, num_patch in zip( + pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of + # (num_patches, num_channels, height, width) + raise ValueError( + f'pixel_values of shape {pixel_values.shape}, ' + 'expect to be of 4 or 5 dimensions') + logger.info(f'vision forward shape: {pixel_values.shape}') + image_outputs = self.model.vision_tower.forward( + pixel_values, output_hidden_states=True) + image_features = image_outputs.hidden_states[ + self.hf_config.vision_feature_layer] + strategy = self.hf_config.vision_feature_select_strategy + if strategy == 'default': + image_features = image_features[:, 1:] + elif strategy == 'full': + image_features = image_features + else: + raise ValueError('Unexpected select feature strategy: ' + f'{strategy}') + image_features = self.model.multi_modal_projector(image_features) + image_features = torch.split(image_features, + image_num_patches, + dim=0) + image_features, feature_lens = self.model.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=strategy, + image_newline=self.model.image_newline, + ) + image_features = torch.split(image_features, + feature_lens.cpu().numpy().tolist(), + dim=0) + outputs.extend(image_features) + messages.append(dict(role='forward', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/mini_gemeni.py b/lmdeploy/vl/model/mini_gemeni.py index 0565daeba5..eca70aca51 100644 --- a/lmdeploy/vl/model/mini_gemeni.py +++ b/lmdeploy/vl/model/mini_gemeni.py @@ -3,16 +3,18 @@ import os.path as osp import warnings from contextlib import contextmanager -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import (add_device_hook, disable_logging, disable_transformers_logging, hack_import_with) +logger = get_logger('lmdeploy') + def check_mini_gemini_install(): """check mini gemini install.""" @@ -22,8 +24,8 @@ def check_mini_gemini_install(): except ImportError: raise ImportError( 'To use MiniGeminiVisionModel, please install minigemini by ' - 'pip install git+https://github.com/dvlab-research/MGM.git' - ' --no-deps') + '`pip install git+https://github.com/dvlab-research/MGM.git' + ' --no-deps`') def 
_build_vision_tower(vision_tower_cfg, **kwargs): @@ -169,7 +171,15 @@ class MiniGeminiVisionModel(VisonModel): _arch = ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM'] + def build_preprocessor(self): + # pytorch engine will not support mini-gemini. Therefore, in order to + # reuse the previous code as much as possible, we do not extract image + # preprocessor from `build_model` function. + pass + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_mini_gemini_install() # empty init from accelerate import init_empty_weights @@ -193,13 +203,12 @@ def build_model(self): vision_tower.load_model() vision_tower_aux = model.get_vision_tower_aux() vision_tower_aux.load_model() + self.vl_model = model if not self.with_llm: del model.lm_head del model.model.embed_tokens del model.model.layers del model.model.norm - else: - self.vl_model = model from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory( @@ -246,11 +255,35 @@ def build_model(self): self.image_processor = image_processor self.process_images = process_images + def preprocess(self, messages: List[Dict]) -> List[Dict]: + return messages + @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - image_tensor = self.process_images(outputs, self.image_processor, + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + images = [] + for message in messages: + if not isinstance(message['content'], List): + continue + _ = [ + x['image'] for x in message['content'] if x['type'] == 'image' + ] + assert len(_) == 1, f'MiniGeminiLlama accepts ONE input ' \ + f'image, but got {len(images)} images' + images.extend(_) + + image_tensor = self.process_images(images, self.image_processor, self.model.config) image_grid = getattr(self.model.config, 'image_grid', 1) if hasattr(self.model.config, 'image_size_aux'): @@ -301,15 +334,47 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]: image.to(self.model.device, dtype=torch.float16) for image in image_tensor_aux ] + logger.info(f'vision forward bs: {len(image_tensor)}') else: image_tensor = image_tensor.to(self.model.device, dtype=torch.float16) image_tensor_aux = image_tensor_aux.to(self.model.device, dtype=torch.float16) - + logger.info(f'vision forward shape: {image_tensor.shape}') images_embeds = self.model.encode_images(image_tensor, image_tensor_aux) outputs = torch.split(images_embeds, 1, dim=0) outputs = [x.squeeze() for x in outputs] - return outputs + messages.append(dict(role='forward', cotent=outputs)) + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt 
= (IMAGE_TOKEN + '\n') * n_images + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + assert 0, 'cogvlm is not supported by pytorch engine' + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py index 4e30190c1d..6b0c5f1508 100644 --- a/lmdeploy/vl/model/minicpmv.py +++ b/lmdeploy/vl/model/minicpmv.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import itertools import warnings from typing import Dict, List import torch from PIL.Image import Image -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel @@ -19,8 +20,33 @@ class MiniCPMVModel(VisonModel): _arch = 'MiniCPMV' + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) + if not hasattr(self.hf_config, 'version'): + raise ValueError('Can not find `version` in config.json. ' + 'Please checkout the latest model') + version = str(self.hf_config.version) + if version not in ['2.5', '2.6']: + raise ValueError( + f'Only support v2.5 and v2.6, but got version {version}') + self.version = version + + def build_preprocessor(self): + from transformers import AutoProcessor + self.processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True) + self.image_processor = self.processor.image_processor + self._preprocess_func = (self._preprocess_v2_5 if self.version == '2.5' + else self._preprocess_v2_6) + def build_model(self): - """build model & load weights.""" + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -29,10 +55,9 @@ def build_model(self): config.quantization_config = {} # disable vision part quantization model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: del model.llm - else: - self.vl_model = model from accelerate import load_checkpoint_and_dispatch with disable_logging(): @@ -50,46 +75,11 @@ def build_model(self): device=model.resampler.proj.device) self.config = config self.model = model.eval() - self.init_forward_func() - - def init_forward_func(self): - if not hasattr(self.config, 'version'): - msg = 'LMDeploy only support `MiniCPM-V-2_6` and '\ - '`MiniCPM-Llama3-V-2_5`.\nCan not find `version` in config, ' \ - 'please consider update the huggingface model.' 
- logger.warn(msg) - - self._forward_func = self._forward_v2_5 - if hasattr(self.config, 'version'): - version = str(self.config.version) - if version == '2.6': - self._forward_func = self._forward_v2_6 - - if self._forward_func == self._forward_v2_5: - logger.info('using _forward_v2_5') - if not hasattr(self.model, 'slice_image'): - # adapt new code commit 287e3f85 (MiniCPM-Llama3-V-2_5) - from transformers import AutoProcessor - processor = AutoProcessor.from_pretrained( - self.model_path, trust_remote_code=True) - self.model.slice_image = processor.image_processor.slice_image - - def _reshape_by_patch(x): - out = x.cpu().numpy() - out = processor.image_processor.reshape_by_patch(out) - return torch.from_numpy(out).to(device=x.device) - - self.model.reshape_by_patch = _reshape_by_patch - - if self._forward_func == self._forward_v2_6: - logger.info('using _forward_v2_6') - from transformers import AutoProcessor - self.model.processor = AutoProcessor.from_pretrained( - self.model_path, trust_remote_code=True) def _get_slice_image(self, image: Image): slice_images = [] - source_image, patches, best_grid = self.model.slice_image(image) + source_image, patches, best_grid = self.image_processor.slice_image( + image) slice_images.append(source_image) if len(patches) > 0: for i in range(len(patches)): @@ -103,114 +93,198 @@ def _reshape_by_patch(self, slice_images): for slice_image in slice_images: slice_image = self.model.transform(slice_image) H, W = slice_image.shape[1:] - patches.append(self.model.reshape_by_patch(slice_image)) + slice_image = slice_image.numpy() + slice_image = self.image_processor.reshape_by_patch(slice_image) + slice_image = torch.from_numpy(slice_image) + patches.append(slice_image) H //= self.config.patch_size W //= self.config.patch_size tgt_sizes.append(torch.Tensor([H, W]).type(torch.int32)) return patches, tgt_sizes - def _forward_v2_5(self, images: List[Image], params: List[Dict] = None): - """forward for MiniCPM-Llama3-V-2_5.""" - patches = [] - tgt_sizes = [] - best_grids = [] - num_patches = [] - for image in images: - slice_images, best_grid = self._get_slice_image(image) - _patches, _tgt_sizes = self._reshape_by_patch(slice_images) - num_patches.append(len(_patches)) - patches.extend(_patches) - tgt_sizes.extend(_tgt_sizes) - best_grids.append(best_grid) - - patches = [ - x.to(dtype=torch.half, device=self.model.device) for x in patches - ] - patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches] - tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) - max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches, - batch_first=True, - padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute(0, 2, - 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=self.model.device) - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True - vision_embedding = self.model.vpm( - all_pixel_values.type(torch.half), - patch_attention_mask=patch_attn_mask).last_hidden_state - vision_embedding = self.model.resampler(vision_embedding, tgt_sizes) - vision_embedding = torch.split(vision_embedding, num_patches, 0) + def _preprocess_v2_5(self, image: Image, params: Dict = None) -> Dict: + """image preprocessing for MiniCPM-Llama3-V-2_5.""" + slice_images, best_grid = self._get_slice_image(image) + # pixel_values, tgt_sizes are list of torch tensors + pixel_values, tgt_sizes = 
self._reshape_by_patch(slice_images) + num_patches = len(pixel_values) + return dict( + pixel_values=pixel_values, # a list + tgt_sizes=tgt_sizes, # a list + best_grid=best_grid, + num_patches=num_patches, + image_tokens=1, + image_token_id=0) + + def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict: + """image preprocessing for MiniCPM-V-2_6.""" + max_slice_nums = self.image_processor.max_slice_nums + use_image_id = self.image_processor.use_image_id + max_slice_nums = params.get('max_slice_nums', max_slice_nums) + use_image_id = params.get('use_image_id', use_image_id) + outputs = self.image_processor(image, max_slice_nums=max_slice_nums) + pixel_values = outputs['pixel_values'][0] + num_patches = len(pixel_values) + pixel_values = [torch.as_tensor(x) for x in pixel_values] + tgt_sizes = outputs['tgt_sizes'][0] + tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes] + grid = self.image_processor.get_sliced_grid( + image_size=image.size, max_slice_nums=max_slice_nums) + return dict( + pixel_values=pixel_values, # a list + tgt_sizes=tgt_sizes, # a list + best_grid=grid, + num_patches=num_patches, + image_tokens=1, + image_token_id=0, + use_image_id=use_image_id) + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" outputs = [] - for embeddings, grid in zip(vision_embedding, best_grids): - embeddings = embeddings.cpu() # n x d x h - outputs.append(dict(embeddings=embeddings, grid=grid)) + for i, message in enumerate(messages): + if message['role'] != 'user' or not isinstance( + message['content'], List): + continue + for item in message['content']: + if item['type'] == 'image': + image = item['image'].convert('RGB') + params = { + k: v + for k, v in item.items() if k not in {'type', 'image'} + } + result = self._preprocess_func(image, params) + outputs.append(result) + messages[i].update(dict(preprocess=outputs)) + return messages - return outputs + @torch.no_grad() + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. 
- def _forward_v2_6(self, images: List[Image], params: List[Dict] = None): - """forward for MiniCPM-V-2_6.""" - patches = [] - tgt_sizes = [] - best_grids = [] - num_patches = [] - max_slice_nums = self.model.processor.image_processor.max_slice_nums - use_image_id = self.model.processor.image_processor.use_image_id - for image, param in zip(images, params): - max_slice_nums = param.get('max_slice_nums', max_slice_nums) - use_image_id = param.get('use_image_id', use_image_id) - outputs = self.model.processor.image_processor( - image, max_slice_nums=max_slice_nums) - patches.extend(outputs['pixel_values'][0]) - num_patches.append(len(outputs['pixel_values'][0])) - tgt_sizes.extend(outputs['tgt_sizes'][0]) - grid = self.model.processor.image_processor.get_sliced_grid( - image_size=image.size, max_slice_nums=max_slice_nums) - best_grids.append(grid) - - patches = [ - torch.as_tensor(x).to(dtype=torch.half, device=self.model.device) - for x in patches + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + # collect preprocess results into a list + inputs = [] + inputs = [ + x['preprocess'] for x in messages if 'preprocess' in x.keys() ] - patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches] - tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes] - tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) - max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches, + # flatten the list + inputs = list(itertools.chain(*inputs)) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + tgt_sizes = [ + x['tgt_sizes'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + num_patches = [ + x['num_patches'] for x in inputs[idx:idx + max_batch_size] + ] + # flatten the list + tgt_sizes = list(itertools.chain(*tgt_sizes)) + pixel_values = list(itertools.chain(*pixel_values)) + pixel_values = [ + x.to(dtype=torch.half, device=self.model.device) + for x in pixel_values + ] + pixel_values = [ + x.flatten(end_dim=1).permute(1, 0) for x in pixel_values + ] + pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute(0, 2, - 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=self.model.device) - for i in range(B): - patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True - vision_embedding = self.model.vpm( - all_pixel_values.type(torch.half), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state - vision_embedding = self.model.resampler(vision_embedding, tgt_sizes) - vision_embedding = torch.split(vision_embedding, num_patches, 0) - outputs = [] - for embeddings, grid in zip(vision_embedding, best_grids): - embeddings = embeddings.cpu() # n x d x h - outputs.append( - dict(embeddings=embeddings, - grid=grid, - use_image_id=use_image_id)) + B, L, _ = pixel_values.shape + pixel_values = pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) + tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) + max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=self.model.device) + logger.info(f'vision forward shape: 
{pixel_values.shape}') + if self.version == '2.5': + for j in range(B): + patch_attn_mask[j, :tgt_sizes[j][0] * + tgt_sizes[j][1]] = True + embeddings = self.model.vpm( + pixel_values.type(torch.half), + patch_attention_mask=patch_attn_mask).last_hidden_state + else: + for j in range(B): + patch_attn_mask[j, 0, :tgt_sizes[j][0] * + tgt_sizes[j][1]] = True + embeddings = self.model.vpm( + pixel_values.type(torch.half), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state - return outputs + embeddings = self.model.resampler(embeddings, tgt_sizes) + embeddings = torch.split(embeddings, num_patches, 0) + for embedding in embeddings: + embedding = embedding.split(1, dim=0) + outputs.extend([x.squeeze() for x in embedding]) + messages.append(dict(role='forward', content=outputs)) + return messages - @torch.no_grad() - def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - images = [x.convert('RGB') for x in images] - return self._forward_func(images, params) + def proc_messages(self, messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + idx = 0 + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + if 'preprocess' not in message.keys(): + continue + prompts = [] + for x in message['preprocess']: + prompt = f'{IMAGE_TOKEN}' + if x.get('use_image_id', False): + prompt = f'{idx}' + prompt + idx += 1 + grid = x['best_grid'] + if grid is not None: + if self.version == '2.5': + slice = '\n'.join( + [f'{IMAGE_TOKEN}' * grid[0]] * + grid[1]) + prompt = f'{prompt}{slice}\n' + elif self.version == '2.6': + slice = '\n'.join( + [f'{IMAGE_TOKEN}' * grid[0]] * + grid[1]) + prompt = prompt + slice + prompt += '\n' + else: + prompt = (prompt + + '\n' if self.version == '2.6' else prompt) + prompts.append(prompt) + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = ''.join(prompts) + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/mllama.py b/lmdeploy/vl/model/mllama.py index db0a0e9cbf..0cae71cd6c 100644 --- a/lmdeploy/vl/model/mllama.py +++ b/lmdeploy/vl/model/mllama.py @@ -2,192 +2,10 @@ from typing import Dict, List -import torch -import torch.nn.functional as F -from PIL.Image import Image -from transformers.modeling_outputs import BaseModelOutput -from transformers.models.mllama.modeling_mllama import MllamaPreTrainedModel - from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging - - -class MllamaVisionModelPatch(MllamaPreTrainedModel): - - def apply_class_embedding(self, - hidden_state: torch.Tensor) -> torch.Tensor: - batch_size, _, hidden_size = hidden_state.shape - class_embedding = self.class_embedding.expand(batch_size, 1, - 
hidden_size) - class_embedding = class_embedding.to(hidden_state.device) - hidden_state = torch.cat([class_embedding, hidden_state], dim=1) - return hidden_state - - def forward( - self, - pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor, - output_attentions: bool = None, - output_hidden_states: bool = None, - return_dict: bool = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa - - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape # noqa - - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, - height, width) - aspect_ratio_ids = aspect_ratio_ids.reshape( - batch_size * num_concurrent_media, -1) - - # Patch embedding - patch_embeds = self.patch_embedding( - pixel_values.to(self.dtype).to(self.device)) - hidden_state = patch_embeds.flatten(2).transpose(1, 2) - - # Tile embeddings - _, num_patches, dim = hidden_state.shape - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, -1, dim) - hidden_state = self.pre_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - - # Add cls token - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim) - hidden_state = self.apply_class_embedding(hidden_state) - num_patches += 1 - - # Position embeddings - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, num_patches, dim) - hidden_state = self.gated_positional_embedding(hidden_state, - aspect_ratio_ids) - - hidden_state = self.layernorm_pre(hidden_state) - - # Compute the number of tokens to pad - num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 - # Compute padding tuple for pad function - padding = ( - 0, 0, 0, num_padding_patches - ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) - # Pad the tensor - hidden_state = F.pad(hidden_state, padding, mode='constant', value=0) - slice_index = -num_padding_patches if num_padding_patches > 0 else None - - # Prepare attention mask - attention_mask = aspect_ratio_mask.reshape( - batch_size * num_concurrent_media, -1) - from transformers.models.mllama.modeling_mllama import \ - _prepare_aspect_ratio_attention_mask - attention_mask = _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask=attention_mask, - num_patches=self.num_patches, - target_length=hidden_state.shape[2], - dtype=self.dtype, - ) - - # Apply encoder - hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, - dim) - output = self.transformer( - hidden_state, - attention_mask=attention_mask, - output_hidden_states=True, - output_attentions=output_attentions, - ) - hidden_state = output[0] - - hidden_state = self.layernorm_post(hidden_state) - - # Apply global encoder - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = self.post_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles * (num_patches + num_padding_patches), dim) - global_output = self.global_transformer( - hidden_state, - attention_mask=attention_mask, - 
output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - hidden_state = global_output[0] - - # Remove padding form hidden state - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = hidden_state[:, :, :slice_index] - hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, - num_tiles, num_patches, dim) - - # Collect intermediate layer outputs from encoder output - all_intermediate_hidden_states = output[1] - # rewrite to sync device during accelerate pipeline parallel - device = hidden_state.device - all_intermediate_hidden_states = [ - s.to(device) for s in all_intermediate_hidden_states - ] - intermediate_hidden_states = torch.stack( - all_intermediate_hidden_states, dim=-1) - intermediate_hidden_states = intermediate_hidden_states[ - ..., self.intermediate_layers_indices] - - # Remove padding from intermediate hidden states - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size * num_concurrent_media, num_tiles, - num_patches + num_padding_patches, -1) - intermediate_hidden_states = intermediate_hidden_states[:, :, : - slice_index] - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, -1) - - # Concatenate final hidden state and intermediate hidden states - hidden_state = torch.cat([hidden_state, intermediate_hidden_states], - dim=-1) - - if output_hidden_states: - hidden_states = tuple(all_intermediate_hidden_states) + tuple( - global_output[1]) - else: - hidden_states = None - - if output_attentions: - # global transformer in contrast to `self.transformer` doesn't - # always return hidden states so we might go index out-of-range - global_attn = tuple( - global_output[2]) if output_hidden_states else tuple( - global_output[1]) - attentions = tuple(output[2]) + global_attn - else: - attentions = None - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states, attentions] - if v is not None) - - return BaseModelOutput( - last_hidden_state=hidden_state, - hidden_states=hidden_states, - attentions=attentions, - ) def check_transformers(): - """check qwen_vl_utils.""" try: from transformers import MllamaForConditionalGeneration # noqa: F401 except ImportError: @@ -202,85 +20,60 @@ class MllamaVLModel(VisonModel): _arch = 'MllamaForConditionalGeneration' - def build_model(self): - check_transformers() - - from transformers.models.mllama.modeling_mllama import \ - MllamaVisionModel - MllamaVisionModel.forward = MllamaVisionModelPatch.forward - MllamaVisionModel.apply_class_embedding = MllamaVisionModelPatch.apply_class_embedding # noqa - from accelerate import init_empty_weights - with init_empty_weights(): - config = self.hf_config - config.quantization_config = {} # disable vision part quantization - # disable accelerate check_tied_parameters_in_config - config.tie_word_embeddings = False - from transformers import MllamaForConditionalGeneration - model = MllamaForConditionalGeneration._from_config(config) - if not self.with_llm: - del model.language_model - else: - self.vl_model = model - - from accelerate import load_checkpoint_and_dispatch - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map='auto' if not self.with_llm else {'': 'cpu'}, - max_memory=self.max_memory, - no_split_module_classes=[ - 'MllamaPrecomputedPositionEmbedding', - 'MllamaPrecomputedAspectRatioEmbedding', 
- 'MllamaVisionEncoderLayer' - ], - dtype=config.torch_dtype) - - self.model = model.eval() - - # processor + def build_preprocessor(self): from transformers import AutoProcessor self.processor = AutoProcessor.from_pretrained(self.model_path) self.image_token_id = 128256 - @torch.no_grad() - def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - # only support image input - if params is not None: - assert len(images) == len( - params), 'different length of images and params' + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to the spec of `super().preprocess`""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + results = self.processor.image_processor(images=image, + return_tensors='pt') + results.update(image_size=image.size, + image_tokens=1, + image_token_id=self.image_token_id) + outputs.append(results) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + def build_model(self): + check_transformers() + if self.with_llm: + from transformers import MllamaForConditionalGeneration + model = MllamaForConditionalGeneration.from_pretrained( + self.model_path, device_map='cpu') + self.vl_model = model else: - params = [{}] * len(images) - # resize images with abnormal shape - # TODO try catch image feature extraction in pipeline and - # throw error back to users - for i, image in enumerate(images): - size = image.size - if any([s < 3 for s in size]): - images[i] = image.resize([s * 3 for s in size]) - image_inputs = self.processor.image_processor(images=images, - return_tensors='pt') - pixel_values = image_inputs['pixel_values'].to( - self.model.vision_model.device) - pixel_values = pixel_values.type(self.model.vision_model.dtype) - aspect_ratio_ids = image_inputs['aspect_ratio_ids'].to( - self.model.vision_model.device) - aspect_ratio_mask = image_inputs['aspect_ratio_mask'].to( - self.model.vision_model.device) - vision_outputs = self.model.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=False, - output_attentions=False, - return_dict=True) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.model.multi_modal_projector( - cross_attention_states) - _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape) - cross_attention_states = cross_attention_states.view( - bsz, -1, image_token_dim).split([1] * len(images)) - return cross_attention_states + raise NotImplementedError('turbomind has not supported mllama yet') + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '<|image|>' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = (IMAGE_TOKEN) * n_images + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + 
sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py index 9abae7a309..eccf62ebb6 100644 --- a/lmdeploy/vl/model/molmo.py +++ b/lmdeploy/vl/model/molmo.py @@ -3,11 +3,9 @@ from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoModelForCausalLM, AutoProcessor from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging @@ -20,20 +18,26 @@ class MolmoVisionModel(VisonModel): _arch = 'MolmoForCausalLM' + def build_preprocessor(self): + self.processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True, + torch_dtype=torch.half, + device_map='auto') + def build_model(self): - """Load model.""" + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights, load_checkpoint_and_dispatch with init_empty_weights(): - config = self.hf_config - model = AutoModelForCausalLM.from_config(config, + model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True) + + self.vl_model = model if not self.with_llm: # Remove nn modules other than embedding from the LLM model for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']: del model.model.transformer[key] - self.token_embedding = model.model.transformer.wte - else: - self.vl_model = model + self.token_embedding = model.model.transformer.wte with disable_logging(): load_checkpoint_and_dispatch( @@ -43,118 +47,161 @@ def build_model(self): max_memory=self.max_memory, no_split_module_classes=[ 'ResidualAttentionBlock', 'Embedding' - ]) + ], + dtype=torch.half) # We need eval mode to freeze the weights in model, thus, # avoid randomness in inference. self.model = model.eval() - self.config = config - self.processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True, - torch_dtype='auto', - device_map='auto') + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to the `super.preprocess() for spec.""" + for i, message in enumerate(messages): + if not isinstance(message['content'], List): + continue + images = [ + x['image'] for x in message['content'] if x['type'] == 'image' + ] + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = f' User: {content[0]}' + tokens = self.processor.tokenizer.encode(prompt, + add_special_tokens=False) + # preprocess images. The output is a dict, which is + # { + # 'input_ids': torch.Tensor, + # 'images': torch.Tensor, # (n_patch, d_model) + # 'image_input_idx': torch.Tensor, # (n_patch, d_model) + # 'image_masks': torch.Tensor, # (n_patch, d_model) + # } + result = self.processor.process(images=images, tokens=tokens) + # remove the bos from input_ids which is prepended by molmo's + # processor + input_ids = result['input_ids'][1:] + result.update(input_ids=input_ids) + messages[i].update(preprocess=result) + return messages @torch.no_grad() def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[Dict]: - """forward the model with given input. + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. 
Args: - images (List): [None] it is not used - params (List): the inputs after precessing GPT4V messages in - `MolmoChatTemplateWrapper`. Its format is like the following: - [[ - {'role': 'user', 'content': 'user prompt'}, - {'role': 'asssistant', 'content': 'assistant prompt'}, - {'role': 'user', 'content': 'user prompt', 'images': [PIL image list]}, - ... - ]] - """ # noqa - - messages = params[0] - assert isinstance(messages, List) - # append an assistant message to `messages` - messages.append(dict(role='assistant', content='')) + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + for i, message in enumerate(messages): + if 'preprocess' not in message.keys(): + continue + inputs = message['preprocess'] + # get input_ids of embedding + inputs = { + k: v.to(self.model.device).unsqueeze(0) + for k, v in inputs.items() + } + input_ids = inputs['input_ids'] + # (batch_size, num_image, num_patch, d_model) + images = inputs['images'] + # (batch_size, num_image, num_patch) + image_input_idx = inputs['image_input_idx'] + image_masks = inputs['image_masks'] + batch_size, seq_len = input_ids.size() + assert batch_size == 1 + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + embeddings = self.model.model.transformer.wte(input_ids) + images = images.to(self.model.dtype) + image_masks = image_masks.to(self.model.dtype) + logger.info(f'vision forward shape: {images.shape}') + image_features, _ = self.model.model.vision_backbone( + images, image_masks) + num_image, num_patch = image_features.shape[1:3] + assert image_input_idx.shape == (batch_size, num_image, num_patch) + + # insert the image feature into the embedding. 
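+            # `image_input_idx` maps every vision patch to the position in
+            # `input_ids` whose embedding should receive that patch feature;
+            # entries of -1 mark padded patches and are dropped by the
+            # `valid` mask before the scatter-add below.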
+ image_features = image_features.view(batch_size, + num_image * num_patch, -1) + image_input_idx = image_input_idx.view(batch_size, + num_image * num_patch) + valid = image_input_idx >= 0 + batch_idx = torch.arange(batch_size, device=embeddings.device) + batch_idx = torch.tile(batch_idx[:, None], + [1, image_features.shape[1]]) + image_features = image_features.to(embeddings.device) + # Since we remove bos_id from input_ids during `preprocess`, + # the index `image_input_idx[valid]` should be shift to left + # by subtracting 1 + index = image_input_idx[valid] - 1 + embeddings[batch_idx[valid], index] += image_features[valid] + assert embeddings.shape[:2] == (batch_size, seq_len) + messages[i].update( + dict(forward=dict(input_ids=input_ids.flatten(), + embeddings=embeddings))) + return messages + + @staticmethod + def proc_messages(messages): + prompt = [] + IMAGE_TOKEN = '' + for message in messages: + role, content = message['role'], message['content'] + if isinstance(content, List): + n_images = len([1 for x in content if x['type'] == 'image']) + content = [x['text'] for x in content if x['type'] == 'text'] + prompt.append(' User: ' + (IMAGE_TOKEN + '\n') * n_images + + content[0]) + else: + if role == 'user': + prompt.append(f' User: {content}') + elif role == 'assistant': + prompt.append(f' Assistant:{content}') + else: + assert 0, f'molmo does not support role {role}, message is {message}' # noqa + prompt.append(' Assistant:') + return ''.join(prompt) + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + assert 0, 'molmo is not supported by pytorch engine' + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): # results is a list of tuple(input_ids, embeddings) results = [] - # the concat prompt. It is not used during inference but to adhere the - # interface definition of `_get_prompt_input` in `class VLAsyncEngine` - prompts = '' # Prepend BOS # qwen2 and olmo do not have a BOS, and instead use EOS as a generic # separator token. bos = (self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id) results.append(([bos], None)) + for i, message in enumerate(messages): - if 'images' in message.keys(): - prompts += ' User: ' + (IMAGE_TOKEN + '\n') * len( - message['images']) + message['content'] - prompt = f' User: {message["content"]}' - tokens = self.processor.tokenizer.encode( - prompt, add_special_tokens=False) - # preprocess images. The output is a dict - inputs = self.processor.process(images=message['images'], - tokens=tokens) - inputs = { - k: v.to(self.model.device).unsqueeze(0) - for k, v in inputs.items() - } - input_ids = inputs['input_ids'] - # remove the bos from input_ids which is prepended by molmo's - # processor - input_ids = input_ids[:, 1:] - images = inputs[ - 'images'] # (batch_size, num_image, num_patch, d_model) - image_input_idx = inputs[ - 'image_input_idx'] # (batch_size, num_image, num_patch) - image_masks = inputs['image_masks'] - batch_size, seq_len = input_ids.size() - assert batch_size == 1 - - # Get embeddings of input. - if input_ids is not None: - input_ids = input_ids * (input_ids != -1).to( - input_ids.dtype) - embeddings = self.model.model.transformer.wte(input_ids) - image_features, _ = self.model.model.vision_backbone( - images, image_masks) - num_image, num_patch = image_features.shape[1:3] - assert image_input_idx.shape == (batch_size, num_image, - num_patch) - - # insert the image feature into the embedding. 
- image_features = image_features.view(batch_size, - num_image * num_patch, -1) - image_input_idx = image_input_idx.view(batch_size, - num_image * num_patch) - - valid = image_input_idx >= 0 - batch_idx = torch.arange(batch_size, device=embeddings.device) - batch_idx = torch.tile(batch_idx[:, None], - [1, image_features.shape[1]]) - image_features = image_features.to(embeddings.device) - embeddings[batch_idx[valid], - image_input_idx[valid]] += image_features[valid] - assert embeddings.shape[:2] == (batch_size, seq_len) - results.append((input_ids.flatten().tolist(), embeddings)) + prompt = '' + role, content = message['role'], message['content'] + if isinstance(content, List): + forward_result = message.pop('forward') + input_ids = forward_result['input_ids'] + embeddings = forward_result['embeddings'] + results.append((input_ids.tolist(), embeddings)) else: - role = message['role'] - content = message['content'] - assert isinstance(content, str) - prompt = '' if role == 'user': prompt = f' User: {content}' elif role == 'assistant': prompt = f' Assistant:{content}' else: assert 0, f'molmo does not support role {role}, message is {message}' # noqa + if i == len(messages) - 1: + # the last message + assert role == 'user', f'the role of last message is expected to be user, but got {role}' # noqa + prompt += ' Assistant:' + if prompt: input_ids = self.processor.tokenizer.encode( prompt, add_special_tokens=False) results.append((input_ids, None)) - prompts += prompt # concat input_ids from results, calculate the range in the input_ids # where embeddings will be copied to @@ -169,9 +216,9 @@ def forward(self, input_embedding_ranges.append((start, end)) input_ids += _input_ids start += len(_input_ids) - return [ - dict(prompt=prompts, - input_ids=input_ids, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges) - ] + + prompt = self.proc_messages(messages) + return dict(prompt=prompt, + input_ids=input_ids, + input_embeddings=input_embeddings, + input_embedding_ranges=input_embedding_ranges) diff --git a/lmdeploy/vl/model/phi3_vision.py b/lmdeploy/vl/model/phi3_vision.py index 032b8404da..ff00b5d1d9 100644 --- a/lmdeploy/vl/model/phi3_vision.py +++ b/lmdeploy/vl/model/phi3_vision.py @@ -1,198 +1,48 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import warnings -from typing import List +from typing import Dict, List -import torch -from PIL.Image import Image from transformers import AutoProcessor -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging - - -# from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py # noqa E501 -def _process_image_embedding(self, pixel_values: torch.Tensor, - image_sizes: torch.Tensor): - """process image embedding.""" - img_embeds = pixel_values - img_sizes = image_sizes - target_device = pixel_values.device - target_dtype = pixel_values.dtype - if self.use_hd_transform and img_sizes is not None and len(img_sizes): - assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501 - # img_embeds: (num_images, max_num_crops, 3, H, W) - # img_sizes: (num_images, 2).view(1, -1) - - bs = img_embeds.shape[0] - # Nx(HW)xC - img_features = self.get_img_features(img_embeds.flatten(0, 1)) - base_feat_height = base_feat_width = int(img_features.shape[1]**0.5) - - assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501 - - # bs x max_num_crops x (24x24) x C - img_features = img_features.view(bs, -1, - base_feat_height * base_feat_width, - self.image_dim_out) - C = self.image_dim_out - H = base_feat_height - - output_imgs = [] - output_len = [] - # training is tensor, inference is list - if isinstance(img_sizes, torch.Tensor): - img_sizes = img_sizes.view(-1, 2) - for _bs in range(bs): - h, w = img_sizes[_bs] - h = h // 336 - w = w // 336 - B_ = h * w - - # 1 x (24x24) x 1024 - global_img_feature = img_features[_bs, :1] - - # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1, H, H, C).reshape( - 1, H // 2, 2, H // 2, 2, - C).contiguous().permute(0, 1, 3, 2, 4, - 5).reshape(1, H // 2, H // 2, - 4 * C).contiguous() - temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) - - # 1 x 156 x 4096 - glb_img = torch.cat([glb_img, temp_glb_GN], - dim=2).reshape(1, -1, 4 * C) - - # (max_num_crops-1) x (12x12) x C - sub_img = img_features[_bs, 1:] - # 16x574x1024 - # get rid of padding sub_img - sub_img = sub_img[:B_] - - # (num_crops, 12, 2, 12, 2, 1024)->(num_crops, 12, 12, 2, 2, 1024) - # -> (num_crops, 12*12, 4*1024) - sub_img = sub_img.reshape(B_, H, H, C).reshape( - B_, H // 2, 2, H // 2, 2, - C).contiguous().permute(0, 1, 3, 2, 4, - 5).reshape(B_, -1, 4 * C).contiguous() - sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( - 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) - temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) - sub_img = torch.cat([sub_img, temp_sub_GN], - dim=2).reshape(1, -1, 4 * C) - # (1, num_img_tokens, 1024*4) - - # glb + sub - if self.hd_transform_order == 'glb_sub': - output_imgs.append( - torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) - elif self.hd_transform_order == 'sub_glb': - output_imgs.append( - torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) - else: - raise NotImplementedError( - f'hd_transform_order = {self.hd_transform_order}' - ) # noqa E501 - - temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) - assert temp_len == output_imgs[-1].shape[ - 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501 - output_len.append(temp_len) - - img_set_tensor = [] - for _output_img in output_imgs: - img_feature_proj = self.img_projection( - 
_output_img.to(target_device).to(target_dtype)) - img_set_tensor.append(img_feature_proj) - elif img_embeds.ndim == 4: - tt = (self.get_img_features(img_embeds).to(target_device).to( - target_dtype).reshape(-1, self.image_dim_out)) - img_set_tensor = self.img_projection(tt) # adapted visual features. - elif img_embeds.ndim == 3: - tt = (img_embeds.to(target_device).to(target_dtype).view( - -1, self.image_dim_out)) - img_set_tensor = self.img_projection(tt) # adapted visual features. - else: - raise NotImplementedError - return img_set_tensor +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel @VISION_MODELS.register_module() -class Phi3VisionModel(VisonModel): - """Llava hf vision model.""" +class Phi3VisionModel(LlavaHfVisionModel): + """Phi3-vision model.""" _arch = 'Phi3VForCausalLM' - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_config(self.hf_config, - trust_remote_code=True) - if not self.with_llm: - del model.lm_head - del model.model.layers - del model.model.norm - del model.model.embed_tokens - del model.model.vision_embed_tokens.wte - else: - self.vl_model = model - - no_split_module_classes = ['CLIPEncoderLayer'] - max_memory = get_balanced_memory( - model, - max_memory=self.max_memory, - dtype=torch.half, - no_split_module_classes=no_split_module_classes) - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_module_classes, - max_memory=max_memory, - dtype=torch.half) - same_device_keys = [('model.vision_embed_tokens.img_projection', - 'model.vision_embed_tokens.sub_GN', - 'model.vision_embed_tokens.glb_GN')] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: - continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] - - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - - model.eval() - self.model = model - # processor + def build_preprocessor(self): processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) if hasattr(processor, 'tokenizer'): del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + processor.tokenizer = None self.processor = processor - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - process_outputs = self.processor.image_processor( - images, return_tensors='pt').to(device=self.model.device, - dtype=self.model.dtype) - pixel_values = process_outputs['pixel_values'] - image_sizes = process_outputs['image_sizes'] - image_features = _process_image_embedding( - self.model.model.vision_embed_tokens, - pixel_values=pixel_values, - image_sizes=image_sizes) - outputs = [x.squeeze() for x in image_features] - return outputs + def build_model(self): + if self.with_llm: + from transformers import AutoModelForCausalLM + self.vl_model = AutoModelForCausalLM.from_pretrained( + self.model_path, device_map='cpu', trust_remote_code=True) + else: + raise NotImplementedError('turbomind has not supported phi3v yet') + + def preprocess(self, messages: List[Dict]) -> 
List[Dict]: + """refers to `super.preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + result = self.processor.image_processor(image, return_tensors='pt') + h = result['image_sizes'][0][0].item() // 336 + w = result['image_sizes'][0][1].item() // 336 + image_tokens = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + result.update( + dict(image_size=image.size, + image_tokens=image_tokens, + image_token_id=0)) + outputs.append(result) + messages.append(dict(role='preprocess', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py index 3968f27d97..49631ccf35 100644 --- a/lmdeploy/vl/model/qwen.py +++ b/lmdeploy/vl/model/qwen.py @@ -1,14 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoModelForCausalLM +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() class QwenVisionModel(VisonModel): @@ -16,19 +18,33 @@ class QwenVisionModel(VisonModel): _arch = 'QWenLMHeadModel' + def build_preprocessor(self): + from torchvision import transforms + from torchvision.transforms import InterpolationMode + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + image_size = self.hf_config.visual['image_size'] + self.image_transform = transforms.Compose([ + transforms.Resize((image_size, image_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights with init_empty_weights(): config = self.hf_config config.quantization_config = {} # disable vision part quantization model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: del model.lm_head for key in ['wte', 'h', 'ln_f']: setattr(model.transformer, key, None) - else: - self.vl_model = model from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory( @@ -60,13 +76,86 @@ def build_model(self): self.model = model.transformer.visual.eval() + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = self.image_transform(image) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=256, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages + @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.model.image_transform(x) for x in outputs] - outputs = torch.stack(outputs, dim=0) - outputs = self.model(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. 
ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.stack(pixel_values, dim=0) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model(pixel_values) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt: + pass + else: + prompt = ''.join([ + f'Picture {str(i)}:{IMAGE_TOKEN}\n' + for i in range(n_images) + ]) + prompt + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py index 3eb3c1541c..ed9da332e0 100644 --- a/lmdeploy/vl/model/qwen2.py +++ b/lmdeploy/vl/model/qwen2.py @@ -1,12 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
- from typing import Dict, List import torch -from PIL.Image import Image from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging def check_qwen_vl_deps_install(): @@ -15,7 +12,7 @@ def check_qwen_vl_deps_install(): import qwen_vl_utils # noqa: F401 except ImportError: raise ImportError( - 'please install qwen_vl_utils by pip install qwen_vl_utils' # noqa: E501 + 'please install qwen_vl_utils by `pip install qwen_vl_utils`' # noqa: E501 ) try: from transformers import Qwen2VLForConditionalGeneration # noqa: F401 @@ -31,85 +28,105 @@ class Qwen2VLModel(VisonModel): _arch = 'Qwen2VLForConditionalGeneration' + def build_preprocessor(self): + check_qwen_vl_deps_install() + from transformers import AutoProcessor + self.processor = AutoProcessor.from_pretrained(self.model_path) + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess()` for spec.""" + from qwen_vl_utils import process_vision_info + + images = self.collect_images(messages) + optional_keys = { + 'resized_height', 'resized_width', 'min_pixels', 'max_pixels' + } + outputs = [] + for image, params in images: + image = image.convert('RGB') + + item = dict(type='image', image=image) + item.update({ + key: params[key] + for key in params.keys() if key in optional_keys + }) + image_inputs, _ = process_vision_info([dict(content=[item])]) + result = self.processor.image_processor(images=image_inputs, + videos=None, + return_tensors='pt') + merge_length = self.processor.image_processor.merge_size**2 + image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length + result.update( + dict(image_size=image.size, + image_tokens=image_tokens, + image_token_id=0)) + outputs.append(result) + messages.append(dict(role='preprocess', content=outputs)) + return messages + def build_model(self): check_qwen_vl_deps_install() from transformers import Qwen2VLForConditionalGeneration if self.with_llm: - model = Qwen2VLForConditionalGeneration.from_pretrained( - self.hf_config._name_or_path, trust_remote_code=True) - model.half() - self.vl_model = model + self.vl_model = Qwen2VLForConditionalGeneration.from_pretrained( + self.model_path, device_map='cpu') else: - from accelerate import init_empty_weights - with init_empty_weights(): - config = self.hf_config - config.quantization_config = { - } # disable vision part quantization - # disable accelerate check_tied_parameters_in_config - # for Qwen2-VL-2B-Instruct - config.tie_word_embeddings = False - - model = Qwen2VLForConditionalGeneration._from_config(config) - del model.model - del model.lm_head - model.half() - from accelerate import load_checkpoint_and_dispatch - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map='auto' if not self.with_llm else {'': 'cpu'}, - max_memory=self.max_memory, - no_split_module_classes=['Qwen2VLVisionBlock'], - dtype=torch.half) - - self.model = model.eval() - - # processor - from transformers import AutoProcessor - self.processor = AutoProcessor.from_pretrained(self.model_path) + raise NotImplementedError( + 'turbomind has not supported qwen2-vl yet') @torch.no_grad() def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - # only support image input - if params is not None: - assert len(images) == len( - params), 'different length of images and params' - else: - params = [{}] * len(images) + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + 
"""extract image feature. ONLY implement it when the backend is + turbomind engine. - from qwen_vl_utils import process_vision_info - images = [x.convert('RGB') for x in images] - content = [] - optional_keys = [ - 'resized_height', 'resized_width', 'min_pixels', 'max_pixels' - ] - for image, param in zip(images, params): - item = dict(type='image', image=image) - item.update({k: param[k] for k in optional_keys if k in param}) - content.append(item) - messages = [dict(content=content)] - image_inputs, _ = process_vision_info(messages) - image_inputs = self.processor.image_processor(images=image_inputs, - videos=None, - return_tensors='pt') - pixel_values = image_inputs['pixel_values'].to( - self.model.visual.get_device()) - image_grid_thw = image_inputs['image_grid_thw'].to( - self.model.visual.get_device()) - pixel_values = pixel_values.type(self.model.visual.get_dtype()) - image_embeds = self.model.visual(pixel_values, - grid_thw=image_grid_thw).cpu() - merge_length = self.processor.image_processor.merge_size**2 - split_size = image_inputs['image_grid_thw'].prod(dim=1) // merge_length - image_embeds = image_embeds.split(split_size.tolist()) + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + assert 0, 'TODO: support turbomind engine' - outputs = [] - for i, embeddings in enumerate(image_embeds): - outputs.append( - dict(embeddings=embeddings, - grid_thw=image_inputs['image_grid_thw'][i].tolist())) - return outputs + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt: + prompt = prompt.replace( + IMAGE_TOKEN, + f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>') + else: + # Qwen2-VL-2B-Instruct will concat image and user prompt + # according to their order in the content list + # we insert image token before user prompt by default. 
The + # user can use custom image token position if they want the + # same decorated prompt as Qwen2-VL + prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \ + n_images + prompt + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + """return to the information needed by pytorch engine.""" + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py index 96bc900c02..3c72d0c29f 100644 --- a/lmdeploy/vl/model/xcomposer2.py +++ b/lmdeploy/vl/model/xcomposer2.py @@ -5,7 +5,7 @@ import sys import warnings from contextlib import contextmanager -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import torch from PIL.Image import Image @@ -19,6 +19,17 @@ logger = get_logger('lmdeploy') +def check_xcomposer_install(): + try: + # WARNING! we have to do this otherwise the model_type is wrong for + # xcomposer2d5 + import decord # noqa: F401 + except ImportError: + raise ImportError( + "No module named 'decord'. Please install decord by `pip install decord`" # noqa + ) + + class ModelType(enum.Enum): """Request type.""" XCOMPOSER2 = enum.auto() @@ -83,6 +94,17 @@ def init_empty_vit(model_path): class Xcomposer2VisionModel(VisonModel): """InternLM-Xcomposer2 vision model.""" + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) + check_xcomposer_install() + self.model_type, self.module = get_xcomposer_type(self.model_path) + logger.info(f'matching type of {self.model_type}') + @classmethod def match(cls, config: AutoConfig): """check whether the config match the model.""" @@ -94,7 +116,37 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + + import torchvision.transforms as transforms + from torchvision.transforms.functional import InterpolationMode + + if self.model_type in [ + ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD + ]: + self.HD_transform = self.module + self.vis_processor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + self.preprocess_func = (self._preprocess_2d5 if self.model_type + == ModelType.XCOMPOSER2D5 else + self._preprocess_4khd_7b) + else: + self.vis_processor = transforms.Compose([ + transforms.Resize( + (self.hf_config.img_size, self.hf_config.img_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + self.preprocess_func = self._preprocess_7b + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights with init_empty_weights(), warnings.catch_warnings(), \ init_empty_vit(self.model_path): @@ -106,23 +158,10 @@ def build_model(self): model.vit.resize_pos() model.vit.vision_tower.vision_model.post_layernorm.to_empty( device='cpu').half() + self.vl_model = model if not self.with_llm: 
del model.model del model.output - else: - self.vl_model = model - - # additional components. - model_type, module = get_xcomposer_type(self.model_path) - logger.info(f'matching type of {model_type}') - if model_type == ModelType.XCOMPOSER2D5: - self.HD_transform = module - self._forward_func = self._forward_2d5 - elif model_type == ModelType.XCOMPOSER2_4KHD: - self.HD_transform = module - self._forward_func = self._forward_4khd_7b - else: - self._forward_func = self._forward_7b from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory( @@ -156,51 +195,117 @@ def build_model(self): self.model = model.eval() - def _forward_2d5(self, images: List[Image]) -> List[torch.Tensor]: - """internlm-xcomposer2d5-7b vit forward.""" - outputs = [x.convert('RGB') for x in images] - hd_num = 6 if len(images) > 1 else 24 - outputs = [self.HD_transform(x, hd_num=hd_num) for x in outputs] - outputs = [ - self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half) - for x in outputs - ] - embeds, split = self.model.vit(outputs, self.model.plora_glb_GN, - self.model.plora_sub_GN) - embeds = self.model.vision_proj(embeds) - embeds = torch.split(embeds, split, dim=1) - embeds = [x.squeeze() for x in embeds] - return embeds - - def _forward_7b(self, images: List[Image]) -> List[torch.Tensor]: - """internlm-xcomposer2-7b vit forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [ - self.model.vis_processor(x).unsqueeze(0).half() for x in outputs - ] - outputs = torch.cat(outputs, dim=0) - outputs = self.model.vit(outputs) - outputs = self.model.vision_proj(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs - - def _forward_4khd_7b(self, images: List[Image]) -> List[torch.Tensor]: - """internlm-xcomposer2-4khd-7b vit forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.HD_transform(x, hd_num=25) for x in outputs] - outputs = [ - self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half) - for x in outputs - ] - embeds, split = self.model.vit(outputs, self.model.plora_glb_GN, - self.model.plora_sub_GN) - embeds = self.model.vision_proj(embeds) - embeds = torch.split(embeds, split, dim=1) - embeds = [x.squeeze() for x in embeds] - return embeds + def _preprocess_2d5(self, image: Image, params: Dict) -> Dict: + """image preprocessing for internlm-xcomposer2d5-7b.""" + hd_num = params.get('hd_num', 24) + image = self.HD_transform(image, hd_num=hd_num) + pixel_values = self.vis_processor(image).unsqueeze(0).half() + w, h = image.size + n_token_per_image = int((h * w + 1) * 400 + 1 + (h + 1) * 20) + return pixel_values, n_token_per_image + + def _preprocess_7b(self, image: Image, params: Dict) -> Dict: + """image preprocessing for internlm-xcomposer2-7b.""" + pixel_values = self.vis_processor(image).unsqueeze(0).half() + return pixel_values, 256 + + def _preprocess_4khd_7b(self, image: Image, params: Dict) -> Dict: + """image preprocessing for internlm-xcomposer2-4khd-7b.""" + image = self.HD_transform(image, hd_num=25) + pixel_values = self.vis_processor(image).unsqueeze(0).half() + w, h = image.size + n_token_per_image = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + return pixel_values, n_token_per_image + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values, n_token = 
self.preprocess_func(image, params) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=n_token, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - return self._forward_func(images) + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + if self.model_type in [ + ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD + ]: + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + embeds, split = self.model.vit(pixel_values, + self.model.plora_glb_GN, + self.model.plora_sub_GN) + embeds = self.model.vision_proj(embeds) + embeds = torch.split(embeds, split, dim=1) + embeds = [x.squeeze() for x in embeds] + else: + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + logger.info(f'vision forward shape: {pixel_values.shape}') + embeds = self.model.vit(pixel_values) + embeds = self.model.vision_proj(embeds) + embeds = torch.split(embeds, 1, dim=0) + embeds = [x.squeeze() for x in embeds] + outputs.extend(embeds) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = ' '.join([IMAGE_TOKEN] * n_images) + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/yi.py b/lmdeploy/vl/model/yi.py index 34b993322e..f8d3a907ff 100644 --- a/lmdeploy/vl/model/yi.py +++ b/lmdeploy/vl/model/yi.py @@ -1,12 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os from contextlib import contextmanager +from os import path as osp +from typing import Dict, List import torch.nn as nn from transformers import AutoConfig from lmdeploy.vl.model.base import VISION_MODELS -from lmdeploy.vl.model.llava import LlavaVisionModel, check_llava_install +from lmdeploy.vl.model.llava import (LlavaVisionModel, check_llava_install, + process_images) from .utils import disable_transformers_logging, rewrite_ctx @@ -96,8 +99,22 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + from transformers import CLIPImageProcessor + vision_tower_name = osp.join(self.model_path, + self.hf_config.mm_vision_tower) + self.image_processor = CLIPImageProcessor.from_pretrained( + vision_tower_name) + config = AutoConfig.from_pretrained(vision_tower_name) + image_size = config.image_size + patch_size = config.patch_size + self.n_token_per_image = (image_size // patch_size)**2 + if self.hf_config.mm_vision_select_feature == 'cls_patch': + self.n_token_per_image += 1 + def build_model(self): - """build model & load weights.""" + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_llava_install() global _model_path @@ -105,3 +122,19 @@ def build_model(self): with init_yi_model(), disable_transformers_logging(): super().build_model() + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = process_images([image], self.image_processor, + self.config) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py deleted file mode 100644 index cdf398868a..0000000000 --- a/lmdeploy/vl/templates.py +++ /dev/null @@ -1,550 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import asyncio -from typing import Dict, List, Tuple, Union - -import PIL -import PIL.Image - -from lmdeploy.archs import get_model_arch -from lmdeploy.model import BaseModel -from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import IMAGE_TOKEN -from lmdeploy.vl.utils import load_image - -logger = get_logger('lmdeploy') - -VLPromptType = Union[str, Tuple[str, PIL.Image.Image], - Tuple[str, List[PIL.Image.Image]]] - - -class VLChatTemplateWrapper: - """vl chat template wrapper.""" - - def __init__(self, chat_template: BaseModel): - self.chat_template = chat_template - - def prompt_to_messages(self, prompt: VLPromptType): - """convert prompt to GTP4V format.""" - messages = { - 'role': 'user', - 'content': [{ - 'type': 'text', - 'text': '', - }] - } - if isinstance(prompt, str): - messages['content'][0]['text'] = prompt - else: - prompt, images = prompt - if not isinstance(images, list): - images = [images] - messages['content'][0]['text'] = prompt - for image in images: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. 
- if isinstance(image, str): - image = load_image(image) - item = { - 'type': 'image_data', - 'image_data': { - 'data': image - } - } - elif isinstance(image, PIL.Image.Image): - item = { - 'type': 'image_data', - 'image_data': { - 'data': image - } - } - else: - raise ValueError( - 'image should be a str(url/path) or PIL.Image.Image') - - messages['content'].append(item) - - return [messages] - - async def async_collect_pil_images( - self, messages: Dict) -> List[Tuple[PIL.Image.Image, Dict]]: - """collect image from messages.""" - images_with_kwargs = [] - for message in messages: - role = message['role'] - content = message['content'] - if role != 'user' or isinstance(content, str): - continue - for item in content: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. - if item['type'] == 'image_url': - item_copy = item['image_url'].copy() - try: - url = item_copy.pop('url') - images_with_kwargs.append([url, item_copy]) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'image_data': - item_copy = item['image_data'].copy() - try: - data = item_copy.pop('data') - images_with_kwargs.append([data, item_copy]) - except KeyError: - logger.error(f'invalid format {message}') - - def _inner_call(i, images): - url_or_data = images[i][0] - images[i][0] = load_image(url_or_data) - - await asyncio.gather(*[ - asyncio.get_event_loop().run_in_executor(None, _inner_call, i, - images_with_kwargs) - for i in range(len(images_with_kwargs)) - ]) - - return images_with_kwargs - - def append_image_token(self, prompt, num_images: int): - """append image token to user prompt.""" - if IMAGE_TOKEN in prompt: - return prompt - return (IMAGE_TOKEN + '\n') * num_images + prompt - - def convert_messages(self, messages, sequence_start=True): - """convert GPT4V message format to GPT4 text format.""" - new_messages = [] - for message in messages: - role = message['role'] - content = message['content'] - if role != 'user' or isinstance(content, str): - if isinstance(content, list): - text = content[0]['text'] - message = {'role': role, 'content': text} - new_messages.append(message) - continue - num_images = 0 - for item in content: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. 
- if item['type'] == 'image_url': - num_images += 1 - elif item['type'] == 'image_data': - num_images += 1 - elif item['type'] == 'text': - prompt = item['text'] - if num_images > 0: - # add IMAGE_TOKEN to user prompt - prompt = self.append_image_token(prompt, num_images) - new_item = {'role': 'user', 'content': prompt} - new_messages.append(new_item) - return new_messages - - def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str: - """convert messages to decorated prompt.""" - if isinstance(messages, str): - return self.chat_template.messages2prompt(messages, sequence_start) - new_messages = self.convert_messages(messages, sequence_start) - return self.chat_template.messages2prompt(new_messages, sequence_start) - - -class LlavaVLChatTemplateWrapper(VLChatTemplateWrapper): - """Llava vl chat template.""" - pass - - -class YiVLChatTemplateWrapper(VLChatTemplateWrapper): - """Yi vl chat template.""" - pass - - -class InternVLChatTemplateWrapper(VLChatTemplateWrapper): - """InternVL chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - # lmdeploy uses as image token - # internvl uses special tags - if IMAGE_TOKEN in prompt and f'{IMAGE_TOKEN}' not in prompt: - prompt = prompt.replace(f'{IMAGE_TOKEN}', - f'{IMAGE_TOKEN}') - prompt = prompt.replace('', '') - prompt = prompt.replace('', '') - prompt = prompt.replace('', '') - elif IMAGE_TOKEN not in prompt: - prompt = f'{IMAGE_TOKEN * num_images}\n' + prompt - return prompt - - -class DeepSeekVLChatTemplateWrapper(VLChatTemplateWrapper): - """DeepSeek vl chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if IMAGE_TOKEN in prompt: - return prompt - logger.error( - f'for deepseek-vl model, the user should insert the {IMAGE_TOKEN} ' - 'to user prompt manually, please read https://lmdeploy.readthedocs' - '.io/en/latest/inference/vl_pipeline.html for more details.') - if num_images == 1: - return f'{IMAGE_TOKEN}{prompt}' - res = '' - for i in range(num_images): - res += f'{IMAGE_TOKEN} is Figure {str(i)}.\n' - res = res + prompt - return res - - -class QwenVLChatTemplateWrapper(VLChatTemplateWrapper): - """Qwen vl chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if IMAGE_TOKEN in prompt: - return prompt - res = '' - for i in range(num_images): - res += f'Picture {str(i)}:{IMAGE_TOKEN}\n' - res = res + prompt - return res - - -class Qwen2VLChatTemplateWrapper(VLChatTemplateWrapper): - """qwen2 vl.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt: - prompt = prompt.replace( - IMAGE_TOKEN, f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>') - else: - # Qwen2-VL-2B-Instruct will concat image and user prompt according - # to their order in the content list - # we insert image token before user prompt by default. 
The user can - # use custom image token position if they want the same decorated - # prompt as Qwen2-VL - prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \ - num_images + prompt - return prompt - - def get_mrope_info(self, - seq_len: int, - grid_thws: List[Tuple[int, int, int]] = None, - embedding_ranges: List[Tuple[int, int]] = None): - import torch - if grid_thws is None: - mrope_position_ids = torch.arange(seq_len).expand(3, -1) - mrope_position_delta = torch.tensor([0], dtype=torch.long) - else: - mrope_position_ids = [ - torch.arange(embedding_ranges[0][0]).expand(3, -1) - ] - st_idx = embedding_ranges[0][0] - for i, (grid_thw, embedding_range) in enumerate( - zip(grid_thws, embedding_ranges)): - llm_grid_t, llm_grid_h, llm_grid_w = grid_thw - llm_grid_h //= 2 - llm_grid_w //= 2 - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() - mrope_position_ids.append( - torch.stack([t_index, h_index, w_index]) + st_idx) - st_idx += max(llm_grid_h, llm_grid_w) - if i < len(embedding_ranges) - 1: - text_len = embedding_ranges[i + - 1][0] - embedding_ranges[i][1] - else: - text_len = seq_len - embedding_range[1] - mrope_position_ids.append( - torch.arange(text_len).expand(3, -1) + st_idx) - st_idx += text_len - mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) - mrope_position_delta = torch.tensor([st_idx - seq_len], - dtype=torch.long) - - return mrope_position_ids, mrope_position_delta - - -class CogVLMChatTemplateWrapper(VLChatTemplateWrapper): - """cogvlm chat template wrapper.""" - - def __init__(self, chat_template: BaseModel): - from lmdeploy.model import Vicuna - self.chat_template = chat_template - self.llm_chat_template = Vicuna(eoa=chat_template.eoa, - stop_words=chat_template.stop_words) - - def convert_messages(self, messages, sequence_start=True): - """convert GPT4V message format to GPT4 text format.""" - new_messages = [] - for message in messages: - role = message['role'] - content = message['content'] - if role != 'user' or isinstance(content, str): - new_messages.append(message) - continue - num_images = 0 - for item in content: - if item['type'] == 'image_url': - num_images += 1 - elif item['type'] == 'image_data': - num_images += 1 - elif item['type'] == 'text': - prompt = item['text'] - - new_item = { - 'role': 'user', - 'content': prompt, - 'num_images': num_images - } - new_messages.append(new_item) - return new_messages - - def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str: - """convert messages to decorated prompt.""" - if isinstance(messages, str): - return self.chat_template.messages2prompt(messages, sequence_start) - new_messages = self.convert_messages(messages, sequence_start) - prompt = '' - for i, msg in enumerate(new_messages): - num_images = msg.pop('num_images', 0) - if num_images == 0: - role = msg['role'] - msg = self.llm_chat_template.messages2prompt([msg], - sequence_start - and i == 0) - msg = dict(role=role, content=msg) - prompt_i = self.chat_template.messages2prompt([msg], sequence_start - and i == 0) - if num_images > 0: - prompt_i = (IMAGE_TOKEN * num_images) + prompt_i - prompt += prompt_i - return prompt - - -class InternLMXComposer2TemplateWrapper(VLChatTemplateWrapper): - """InternLM-XComposer2 chat template.""" - - def append_image_token(self, prompt, num_images: 
int): - if IMAGE_TOKEN in prompt: - return prompt - logger.warning(f'auto append {IMAGE_TOKEN} at the beginning, ' - 'the user can manually insert the token to prompt') - return ' '.join([IMAGE_TOKEN] * num_images) + prompt - - -class MiniGeminiLlamaTempateWrapper(VLChatTemplateWrapper): - """Qwen vl chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if num_images == 0: - return prompt - if IMAGE_TOKEN in prompt: - return prompt - res = f'{IMAGE_TOKEN}\n' - assert num_images <= 1, 'MiniGeminiLlama accepts 1 input image' - res = res + prompt - return res - - -class MllamaTempateWrapper(VLChatTemplateWrapper): - """Mllama chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - return f'{IMAGE_TOKEN * num_images}{prompt}' - - -class MiniCPMVTempateWrapper(VLChatTemplateWrapper): - """MiniCPM-Llama3-V-2_5 chat template.""" - - def append_image_token(self, prompt, num_images: int): - if IMAGE_TOKEN in prompt: - return prompt - prompt = f'{IMAGE_TOKEN}\n' * num_images + prompt - return prompt - - def update_image_token(self, prompt, features): - _features = [] - _prompt = [] - segs = prompt.split(f'{IMAGE_TOKEN}\n') - for i, seg in enumerate(segs): - if i > 0 and i <= len(features): - _feat = features[i - 1]['embeddings'].split(1) - _feat = [x.squeeze() for x in _feat] - _features.extend(_feat) - _seg = f'{IMAGE_TOKEN}' - if len(_feat) > 1: - grid = features[i - 1]['grid'] - if grid is not None: - _slice = '\n'.join( - [f'{IMAGE_TOKEN}' * grid[0]] * - grid[1]) - _seg = f'{_seg}{_slice}\n' - _prompt.append(_seg) - _prompt.append(seg) - _prompt = ''.join(_prompt) - return _prompt, _features - - -class MiniCPMV26TempateWrapper(MiniCPMVTempateWrapper): - """MiniCPM-V-2_6 chat template.""" - - def update_image_token(self, prompt, features): - _features = [] - _prompt = [] - segs = prompt.split(f'{IMAGE_TOKEN}\n') - idx = 0 - for i, seg in enumerate(segs): - if i > 0 and i <= len(features): - _feat = features[i - 1]['embeddings'].split(1) - _feat = [x.squeeze() for x in _feat] - _features.extend(_feat) - _seg = f'{IMAGE_TOKEN}' - if features[i - 1].get('use_image_id', False): - _seg = f'{idx}' + _seg - idx += 1 - if len(_feat) > 1: - grid = features[i - 1]['grid'] - if grid is not None: - _slice = '\n'.join( - [f'{IMAGE_TOKEN}' * grid[0]] * - grid[1]) - _seg = _seg + _slice - _seg += '\n' - _prompt.append(_seg) - _prompt.append(seg) - _prompt = ''.join(_prompt) - return _prompt, _features - - -class GLM4VChatTemplateWrapper(VLChatTemplateWrapper): - """glm-4v chat template.""" - pass - - -class MolmoChatTemplateWrapper(VLChatTemplateWrapper): - - async def async_collect_pil_images( - self, messages: List[Dict]) -> List[Tuple[PIL.Image.Image, Dict]]: - """collect images from messages. 
- - Args: - messages (List[Dict]): a user request of GPT4V message format - """ - if isinstance(messages, Dict): - messages = [messages] - assert isinstance(messages, List) - - out_messages = [None] * len(messages) - - def _inner_call(i, in_messages, out_messages): - role = in_messages[i]['role'] - content = in_messages[i]['content'] - if role != 'user' or isinstance(content, str): - # means message is user's prompt input or assistant's prompt, - # returning it directory - out_messages[i] = in_messages[i] - return - # the role is a user and the content is a list - assert isinstance(content, List) - message = dict(role=role, content='', images=[]) - for item in content: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. - if item['type'] == 'image_url': - try: - image = load_image(item['image_url']['url']) - message['images'].append(image) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'image_data': - try: - image = load_image(item['image_data']['data']) - message['images'].append(image) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'text': - message['content'] = item['text'] - else: - logger.error(f'unexpected content type {message}') - out_messages[i] = message - - await asyncio.gather(*[ - asyncio.get_event_loop().run_in_executor(None, _inner_call, i, - messages, out_messages) - for i in range(len(messages)) - ]) - return [(None, out_messages)] - - def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str: - """Return a placeholder "IMAGE_TOKEN" so that - `vl_asyn_engine._get_prompt_input` can know that it.""" - if isinstance(messages, str): - return self.chat_template.messages2prompt(messages, sequence_start) - else: - _messages = [] - for message in messages: - role, content = message['role'], message['content'] - if role != 'user' or isinstance(content, str): - _messages.append(message) - continue - for item in content: - item_type = item['type'] - if item_type in ['image_url', 'image_data']: - # Return the image placeholder so that - # `vl_asyn_engine._get_prompt_input` can know that the - # request contains images - return IMAGE_TOKEN - _messages.append(dict(role=role, content=item[item_type])) - return self.chat_template.messages2prompt(_messages, - sequence_start) - - -def get_vl_prompt_template(model_path: str, chat_template: BaseModel, - model_name: str) -> VLChatTemplateWrapper: - """get vision language prompt template.""" - assert type(chat_template) != type(BaseModel()), 'failed to match ' \ - 'chat template, please explicit set chat_template_config' # noqa E721 - if model_name == 'yi-vl': - return YiVLChatTemplateWrapper(chat_template) - arch, cfg = get_model_arch(model_path) - if arch == 'QWenLMHeadModel': - return QwenVLChatTemplateWrapper(chat_template) - elif arch in [ - 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM', - 'LlavaForConditionalGeneration', - 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM' - ]: - return LlavaVLChatTemplateWrapper(chat_template) - elif arch == 'MultiModalityCausalLM': # deepseek-vl - return DeepSeekVLChatTemplateWrapper(chat_template) - elif arch == 'MllamaForConditionalGeneration': # llama 3.2 - return MllamaTempateWrapper(chat_template) - elif arch == 'CogVLMForCausalLM': - return CogVLMChatTemplateWrapper(chat_template) - elif arch in ['InternLMXComposer2ForCausalLM', 'InternLM2ForCausalLM']: - return InternLMXComposer2TemplateWrapper(chat_template) - elif arch == 'InternVLChatModel': - 
return InternVLChatTemplateWrapper(chat_template) - elif arch in ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM']: - return MiniGeminiLlamaTempateWrapper(chat_template) - elif arch == 'MiniCPMV': - version_map = { - '2.5': MiniCPMVTempateWrapper, - '2.6': MiniCPMV26TempateWrapper - } - version = str(getattr(cfg, 'version', '2.5')) - return version_map[version](chat_template) - elif arch == 'ChatGLMModel': - return GLM4VChatTemplateWrapper(chat_template) - elif arch == 'Qwen2VLForConditionalGeneration': - return Qwen2VLChatTemplateWrapper(chat_template) - elif arch == 'MolmoForCausalLM': - return MolmoChatTemplateWrapper(chat_template) - raise ValueError(f'unsupported vl_prompt_template with arch {arch}') diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index 05d74bbe72..81f538275c 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -1,5 +1,5 @@ accelerate>=0.29.3 -dlinfer-ascend>=0.1.2 +dlinfer-ascend>=0.1.3 einops fastapi fire @@ -16,7 +16,8 @@ safetensors sentencepiece shortuuid tiktoken -torch<=2.4.0,>=2.0.0 -torchvision<=0.19.0,>=0.15.0 +torch<=2.4.0,>=2.3.1 +torch-npu==2.3.1 +torchvision<=0.19.0,>=0.18.1 transformers uvicorn diff --git a/requirements/runtime.txt b/requirements/runtime_cuda.txt similarity index 82% rename from requirements/runtime.txt rename to requirements/runtime_cuda.txt index ec4957608c..41af6039bd 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime_cuda.txt @@ -15,8 +15,8 @@ safetensors sentencepiece shortuuid tiktoken -torch<=2.4.0,>=2.0.0 -torchvision<=0.19.0,>=0.15.0 +torch<=2.5.1,>=2.0.0 +torchvision<=0.20.1,>=0.15.0 transformers triton==3.0.0; sys_platform == "linux" uvicorn diff --git a/requirements/runtime_maca.txt b/requirements/runtime_maca.txt new file mode 100644 index 0000000000..f65b3827cd --- /dev/null +++ b/requirements/runtime_maca.txt @@ -0,0 +1,22 @@ +accelerate==0.32.1 +einops +fastapi +fire +mmengine-lite +numpy<2.0.0 +openai +outlines<0.1.0 +peft<=0.11.1 +pillow +protobuf +pydantic>2.0.0 +pynvml +safetensors +sentencepiece +shortuuid +tiktoken +torch<=2.4.0,>=2.0.0 +torchvision<=0.19.0,>=0.15.0 +transformers +triton>=2.1.0; sys_platform == "linux" +uvicorn diff --git a/requirements.txt b/requirements_cuda.txt similarity index 70% rename from requirements.txt rename to requirements_cuda.txt index 91d38808f1..7c1d387dfb 100644 --- a/requirements.txt +++ b/requirements_cuda.txt @@ -1,4 +1,4 @@ -r requirements/build.txt --r requirements/runtime.txt +-r requirements/runtime_cuda.txt -r requirements/lite.txt -r requirements/serve.txt diff --git a/requirements_maca.txt b/requirements_maca.txt new file mode 100644 index 0000000000..075b132c8c --- /dev/null +++ b/requirements_maca.txt @@ -0,0 +1,4 @@ +-r requirements/build.txt +-r requirements/runtime_maca.txt +-r requirements/lite.txt +-r requirements/serve.txt diff --git a/setup.py b/setup.py index 7a08ac7919..52e180d8a2 100644 --- a/setup.py +++ b/setup.py @@ -4,18 +4,14 @@ from setuptools import find_packages, setup -npu_available = False -try: - import torch_npu - - npu_available = torch_npu.npu.is_available() -except ImportError: - pass - pwd = os.path.dirname(__file__) version_file = 'lmdeploy/version.py' +def get_target_device(): + return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda') + + def readme(): with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f: content = f.read() @@ -154,16 +150,12 @@ def gen_packages_items(): setup_requires=parse_requirements('requirements/build.txt'), 
tests_require=parse_requirements('requirements/test.txt'), install_requires=parse_requirements( - 'requirements/runtime_ascend.txt' - if npu_available else 'requirements/runtime.txt'), + f'requirements/runtime_{get_target_device()}.txt'), extras_require={ 'all': - parse_requirements('requirements_ascend.txt' - if npu_available else 'requirements.txt'), - 'lite': - parse_requirements('requirements/lite.txt'), - 'serve': - parse_requirements('requirements/serve.txt') + parse_requirements(f'requirements_{get_target_device()}.txt'), + 'lite': parse_requirements('requirements/lite.txt'), + 'serve': parse_requirements('requirements/serve.txt') }, has_ext_modules=check_ext_modules, classifiers=[ diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.cu b/src/turbomind/kernels/gemm/moe_utils_v2.cu index a9e4f7da51..44fec67748 100644 --- a/src/turbomind/kernels/gemm/moe_utils_v2.cu +++ b/src/turbomind/kernels/gemm/moe_utils_v2.cu @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/src/turbomind/kernels/gemm/test/test_utils.cu b/src/turbomind/kernels/gemm/test/test_utils.cu index 8f2b4007f6..8ee595ab9b 100644 --- a/src/turbomind/kernels/gemm/test/test_utils.cu +++ b/src/turbomind/kernels/gemm/test/test_utils.cu @@ -84,7 +84,7 @@ FastCompare(const T* src, const T* ref, int dims, int bsz, cudaStream_t stream, thrust::cuda::par.on(stream), zip_iter, zip_iter + count, - [=] __device__(auto tup) { + [=] __host__ __device__(thrust::tuple tup) -> Tuple { float s = thrust::get<0>(tup); float r = thrust::get<1>(tup); float abs_diff = fabsf(s - r); diff --git a/tests/pytorch/engine/test_request.py b/tests/pytorch/engine/test_request.py index 813a30e8e7..68ef6b9db9 100644 --- a/tests/pytorch/engine/test_request.py +++ b/tests/pytorch/engine/test_request.py @@ -3,7 +3,7 @@ import pytest from lmdeploy.pytorch.engine.request import (RequestManager, RequestType, - Response, ResponseType) + ResponseType) class TestRequestHander: @@ -17,36 +17,31 @@ def event_loop(self): asyncio.set_event_loop(old_loop) @pytest.fixture - def thread_safe(self, request): - yield request.param + def manager(self): + yield RequestManager() - @pytest.fixture - def manager(self, thread_safe): - yield RequestManager(thread_safe=thread_safe) - - @pytest.mark.parametrize('thread_safe', [True, False]) def test_bind(self, manager, event_loop): def __stop_engine_callback(reqs, **kwargs): for req in reqs: - manager.response( - Response(type=ResponseType.SUCCESS, - sender_id=req.sender_id, - req_id=req.req_id, - data=f'{req.data} success')) + resp = req.resp + resp.type = ResponseType.SUCCESS + resp.data = f'{req.data} success' + manager.response(resp) async def __dummy_loop(): while True: - manager.step() - await asyncio.sleep(0.1) + try: + await manager.step() + except Exception: + return - asyncio.set_event_loop(event_loop) sender = manager.build_sender() manager.start_loop(__dummy_loop) # test not bind - req_id = sender.send_async(RequestType.STOP_ENGINE, None) - resp = sender.recv(req_id) + resp = sender.send_async(RequestType.STOP_ENGINE, None) + resp = sender.recv(resp) assert resp.type == ResponseType.HANDLER_NOT_EXIST assert manager.is_loop_alive() @@ -54,6 +49,8 @@ async def __dummy_loop(): # test bind success sender.send_async(RequestType.STOP_ENGINE, None) manager.bind_func(RequestType.STOP_ENGINE, __stop_engine_callback) - req_id = sender.send_async(RequestType.STOP_ENGINE, 'test') - resp = sender.recv(req_id) + resp = sender.send_async(RequestType.STOP_ENGINE, 'test') + resp = sender.recv(resp) assert 
resp.data == 'test success'
+
+        manager.stop_loop()
diff --git a/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py b/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py
new file mode 100644
index 0000000000..bb165658dd
--- /dev/null
+++ b/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py
@@ -0,0 +1,231 @@
+import pytest
+import torch
+
+
+def _make_A(M, K, group_size, out_dtype, device='cuda'):
+    quant_A = torch.rand(M,
+                         K // group_size,
+                         group_size,
+                         dtype=torch.float32,
+                         device=device)
+    # -1 ~ 1
+    quant_A = quant_A * 2 - 1
+    # scaling abs max to fmax
+    finfo = torch.finfo(out_dtype)
+    fmax = finfo.max
+    scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
+    quant_A *= scaling
+    quant_A = quant_A.to(out_dtype).to(torch.float32)
+
+    # create scale and A
+    scale = torch.rand(M, K // group_size, dtype=torch.float32, device=device)
+    scale /= fmax
+    A = quant_A * scale[..., None]
+
+    A = A.reshape(M, K)
+    quant_A = quant_A.reshape(M, K).to(out_dtype)
+    return A, quant_A, scale
+
+
+def _make_B(E, K, N, group_size, out_dtype, device='cuda'):
+    quant_B = torch.rand(E,
+                         N // group_size,
+                         group_size,
+                         K // group_size,
+                         group_size,
+                         dtype=torch.float32,
+                         device=device)
+    quant_B = quant_B * 2 - 1
+
+    # scaling abs max to fmax
+    finfo = torch.finfo(out_dtype)
+    fmax = finfo.max
+    scaling = fmax / quant_B.abs().amax((2, 4), keepdim=True)
+    quant_B *= scaling
+    quant_B = quant_B.to(out_dtype).to(torch.float32)
+
+    scale = torch.rand(E,
+                       N // group_size,
+                       1,
+                       K // group_size,
+                       1,
+                       dtype=torch.float32,
+                       device=device)
+    scale /= fmax
+
+    B = quant_B * scale
+
+    B = B.reshape(E, N, K)
+    quant_B = quant_B.reshape(E, N, K).to(out_dtype)
+    scale = scale.reshape(E, N // group_size, K // group_size)
+    return B, quant_B, scale
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9,
+                    reason='require device with cc>=9.0')
+class TestFusedMoeBlockedFP8:
+
+    @pytest.fixture
+    def dtype(self):
+        yield torch.float16
+
+    @pytest.fixture
+    def quant_dtype(self):
+        yield torch.float8_e4m3fn
+
+    @pytest.fixture
+    def device(self):
+        yield torch.device('cuda')
+
+    @pytest.fixture
+    def in_size(self):
+        yield 512
+
+    @pytest.fixture
+    def seq_len(self):
+        yield 128
+
+    @pytest.fixture
+    def hidden_size(self):
+        yield 2048
+
+    @pytest.fixture
+    def out_size(self):
+        yield 1024
+
+    @pytest.fixture
+    def num_experts(self):
+        yield 4
+
+    @pytest.fixture
+    def top_k(self):
+        yield 2
+
+    @pytest.fixture
+    def group_size(self):
+        yield 128
+
+    @pytest.fixture
+    def renormalize(self):
+        yield True
+
+    @pytest.fixture
+    def build_hidden_states(self, seq_len, in_size, group_size, quant_dtype,
+                            device):
+        yield _make_A(seq_len,
+                      in_size,
+                      group_size=group_size,
+                      out_dtype=quant_dtype,
+                      device=device)
+
+    @pytest.fixture
+    def hidden_states(self, build_hidden_states, dtype):
+        yield build_hidden_states[0].to(dtype)
+
+    @pytest.fixture
+    def states_quanted(self, build_hidden_states):
+        yield build_hidden_states[1]
+
+    @pytest.fixture
+    def states_scale(self, build_hidden_states):
+        yield build_hidden_states[2]
+
+    @pytest.fixture
+    def build_w1(self, num_experts, hidden_size, in_size, group_size,
+                 quant_dtype, device):
+        yield _make_B(num_experts,
+                      in_size,
+                      hidden_size,
+                      group_size=group_size,
+                      out_dtype=quant_dtype,
+                      device=device)
+
+    @pytest.fixture
+    def w1(self, build_w1, dtype):
+        yield build_w1[0].to(dtype)
+
+    @pytest.fixture
+    def w1_quant(self, build_w1):
+        yield build_w1[1]
+
+    @pytest.fixture
+    def w1_scale(self, build_w1):
+        yield build_w1[2]
+
+    @pytest.fixture
+    def
build_w2(self, num_experts, out_size, hidden_size, group_size, + quant_dtype, device): + yield _make_B(num_experts, + hidden_size // 2, + out_size, + group_size=group_size, + out_dtype=quant_dtype, + device=device) + + @pytest.fixture + def w2(self, build_w2, dtype): + yield build_w2[0].to(dtype) + + @pytest.fixture + def w2_quant(self, build_w2): + yield build_w2[1] + + @pytest.fixture + def w2_scale(self, build_w2): + yield build_w2[2] + + @pytest.fixture + def router_logits(self, seq_len, num_experts, dtype, device): + yield torch.rand(seq_len, num_experts, dtype=dtype, device=device) + + @pytest.fixture + def topk_logits(self, router_logits, top_k): + routing_weights = torch.softmax(router_logits, + dim=-1, + dtype=torch.float32) + yield torch.topk(routing_weights, top_k, dim=-1) + + @pytest.fixture + def topk_weights(self, topk_logits): + yield topk_logits[0] + + @pytest.fixture + def topk_idx(self, topk_logits): + yield topk_logits[1] + + @pytest.fixture + def gt(self, hidden_states, w1, w2, topk_weights, topk_idx, top_k, + renormalize): + from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe + output = fused_moe(hidden_states, + w1, + w2, + topk_weights, + topk_idx, + topk=top_k, + renormalize=renormalize) + yield output + + @torch.inference_mode() + def test_fused_moe(self, states_quanted, states_scale, w1_quant, w1_scale, + w2_quant, w2_scale, topk_weights, topk_idx, top_k, + renormalize, gt): + from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \ + fused_moe_blocked_fp8 + output = fused_moe_blocked_fp8(states_quanted, + states_scale, + w1_quant, + w1_scale, + w2_quant, + w2_scale, + topk_weights, + topk_idx, + topk=top_k, + renormalize=renormalize) + out_max = output.abs().max() + gt_max = gt.abs().max() + assert (out_max - gt_max).abs() / out_max < 0.05 + + norm_out = output / out_max + norm_gt = gt / gt_max + torch.testing.assert_close(norm_out, norm_gt, atol=0.05, rtol=1e-3) diff --git a/tests/pytorch/kernel/test_fused_moe.py b/tests/pytorch/kernel/test_fused_moe.py index 55e3a75c08..cc309eb6a7 100644 --- a/tests/pytorch/kernel/test_fused_moe.py +++ b/tests/pytorch/kernel/test_fused_moe.py @@ -250,3 +250,54 @@ def test_fused_moe(self, hidden_states, w1, w2, topk_weights, topk_idx, topk=top_k, renormalize=renormalize) torch.testing.assert_close(output, gt, atol=1e-3, rtol=1e-3) + + +class TestFusedMoeW8A8(TestFusedMoe): + + @pytest.fixture + def quant_states(self, hidden_states): + from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \ + per_token_quant_int8 + states_i8, states_scale = per_token_quant_int8(hidden_states, 1e-7) + yield states_i8, states_scale + + def quant_weight(self, w): + from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \ + per_channel_quant + num_experts, num_outs, _ = w.shape + w = w.flatten(0, -2) + w_i8, w_scale = per_channel_quant(w, torch.int8) + w_i8 = w_i8.view(num_experts, num_outs, -1) + w_scale = w_scale.view(num_experts, num_outs, -1) + return w_i8, w_scale + + @pytest.fixture + def quant_w1(self, w1): + w_i8, w_scale = self.quant_weight(w1) + yield w_i8, w_scale + + @pytest.fixture + def quant_w2(self, w2): + w_i8, w_scale = self.quant_weight(w2) + yield w_i8, w_scale + + @torch.inference_mode() + def test_fused_moe(self, quant_states, quant_w1, quant_w2, topk_weights, + topk_idx, top_k, renormalize, gt): + from lmdeploy.pytorch.kernels.cuda.w8a8_fused_moe import fused_moe_w8a8 + state_i8, state_scale = quant_states + w1_i8, w1_scale = quant_w1 + w2_i8, w2_scale = quant_w2 + + output = 
fused_moe_w8a8(state_i8, + state_scale, + w1_i8, + w1_scale, + w2_i8, + w2_scale, + topk_weights=topk_weights, + topk_ids=topk_idx, + topk=top_k, + out_dtype=torch.float16, + renormalize=renormalize) + torch.testing.assert_close(output, gt, atol=5e-3, rtol=1e-3) diff --git a/tests/pytorch/kernel/test_gemm_fp8.py b/tests/pytorch/kernel/test_gemm_fp8.py new file mode 100644 index 0000000000..242a2db581 --- /dev/null +++ b/tests/pytorch/kernel/test_gemm_fp8.py @@ -0,0 +1,193 @@ +import pytest +import torch + + +def _make_A(M, K, group_size, out_dtype): + quant_A = torch.rand(M, + K // group_size, + group_size, + dtype=torch.float32, + device='cuda') + # -1 ~ 1 + quant_A = quant_A * 2 - 1 + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_A.abs().amax(-1, keepdim=True) + quant_A *= scaling + quant_A = quant_A.to(out_dtype).to(torch.float32) + + # create scale and A + scale = torch.rand(M, K // group_size, dtype=torch.float32, device='cuda') + scale /= fmax + A = quant_A * scale[..., None] + + A = A.reshape(M, K) + quant_A = quant_A.reshape(M, K).to(out_dtype) + return A, quant_A, scale + + +def _aligned_size(a, b): + return (a + b - 1) // b * b + + +def _make_B(K, N, group_size, out_dtype): + K_aligned = _aligned_size(K, group_size) + N_aligned = _aligned_size(N, group_size) + + quant_B = torch.rand(K_aligned // group_size, + group_size, + N_aligned // group_size, + group_size, + dtype=torch.float32, + device='cuda') + quant_B = quant_B * 2 - 1 + + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True) + quant_B *= scaling + quant_B = quant_B.to(out_dtype).to(torch.float32) + + scale = torch.rand(K_aligned // group_size, + 1, + N_aligned // group_size, + 1, + dtype=torch.float32, + device='cuda') + scale /= fmax + + B = quant_B * scale + + B = B.reshape(K_aligned, N_aligned)[:K, :N] + quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N] + scale = scale.reshape(K_aligned // group_size, N_aligned // group_size) + return B, quant_B, scale + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, + reason='require device with cc>=9.0') +class TestQuantFP8: + + @pytest.fixture + def M(self): + yield 256 + + @pytest.fixture + def K(self): + yield 512 + + @pytest.fixture + def group_size(self): + yield 128 + + @pytest.fixture + def out_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def build_A(self, M, K, group_size, out_dtype): + return _make_A(M, K, group_size, out_dtype) + + @pytest.fixture + def A(self, build_A): + return build_A[0] + + @pytest.fixture + def quant_A(self, build_A): + return build_A[1] + + @pytest.fixture + def scale(self, build_A): + return build_A[2] + + @pytest.fixture + def gt(self, quant_A, scale): + yield quant_A, scale + + def test_quant_fp8(self, A, group_size, out_dtype, gt): + from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 + quant_A_gt, scale_gt = gt + + quant_A, scale = quant_fp8(A, group_size=group_size, dtype=out_dtype) + torch.testing.assert_close(scale, scale_gt) + diff = (quant_A.to(torch.float16) - quant_A_gt.to(torch.float16)).abs() + diff_count = (diff > 1e-5).count_nonzero() + assert diff_count / diff.numel() < 1e-4 + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, + reason='require device with cc>=9.0') +class TestGemmFP8: + + @pytest.fixture + def M(self): + yield 256 + + @pytest.fixture + def N(self): + # test non-aligned + yield 1024 + 64 + 
+ @pytest.fixture + def K(self): + yield 512 + + @pytest.fixture + def group_size(self): + yield 128 + + @pytest.fixture + def quant_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def out_dtype(self): + yield torch.float16 + + @pytest.fixture + def build_A(self, M, K, group_size, quant_dtype): + return _make_A(M, K, group_size, quant_dtype) + + @pytest.fixture + def A(self, build_A, out_dtype): + return build_A[0].to(out_dtype) + + @pytest.fixture + def quant_A(self, build_A): + return build_A[1] + + @pytest.fixture + def scale_A(self, build_A): + return build_A[2] + + @pytest.fixture + def build_B(self, K, N, group_size, quant_dtype): + return _make_B(K, N, group_size, quant_dtype) + + @pytest.fixture + def B(self, build_B, out_dtype): + return build_B[0].to(out_dtype) + + @pytest.fixture + def quant_B(self, build_B): + return build_B[1] + + @pytest.fixture + def scale_B(self, build_B): + return build_B[2] + + @pytest.fixture + def gt(self, A, B): + yield A @ B + + def test_gemm_fp8(self, quant_A, scale_A, quant_B, scale_B, out_dtype, gt): + from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import \ + blocked_gemm_fp8 + C = blocked_gemm_fp8(quant_A, + scale_A, + quant_B, + scale_B, + out_dtype=out_dtype) + torch.testing.assert_close(C, gt, atol=0.5, rtol=1e-4) diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py index 3b78053a74..0e53283a87 100644 --- a/tests/test_lmdeploy/test_model.py +++ b/tests/test_lmdeploy/test_model.py @@ -220,7 +220,7 @@ def test_llama3_1(): }, }] actual_prompt = model.messages2prompt(messages, tools=tools) - expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa + expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": 
"spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa assert actual_prompt == expected_prompt diff --git a/tests/test_lmdeploy/test_vl_encode.py b/tests/test_lmdeploy/test_vl/test_vl_encode.py similarity index 100% rename from tests/test_lmdeploy/test_vl_encode.py rename to tests/test_lmdeploy/test_vl/test_vl_encode.py diff --git a/tests/test_lmdeploy/test_vl_template.py b/tests/test_lmdeploy/test_vl_template.py deleted file mode 100644 index cf8abf9e44..0000000000 --- a/tests/test_lmdeploy/test_vl_template.py +++ /dev/null @@ -1,132 +0,0 @@ -import PIL - -from lmdeploy.model import MODELS -from lmdeploy.vl.constants import IMAGE_TOKEN -from lmdeploy.vl.templates import VLChatTemplateWrapper - - -def test_prompt_to_messages(): - model = MODELS.get('llava-v1')() - templtae = VLChatTemplateWrapper(model) - out = templtae.prompt_to_messages('hi') - assert isinstance(out, list) and isinstance(out[0], dict) - im = PIL.Image.new(mode='RGB', size=(200, 200)) - out = templtae.prompt_to_messages(('hi', [im])) - assert isinstance(out, list) and isinstance(out[0], dict) - - -def test_messages2prompt(): - model = MODELS.get('llava-v1')() - templtae = VLChatTemplateWrapper(model) - messages = [ - dict(role='user', - content=[ - dict(type='text', text='q1'), - dict(type='image_url', image_url=dict(url='xxx')) - ]) - ] - prompt = templtae.messages2prompt(messages) - assert isinstance(prompt, str) - assert prompt.count(IMAGE_TOKEN) == 1 - expected = ( - 'A chat between a curious human and an artificial intelligence ' - 'assistant. The assistant gives helpful, detailed, and polite ' - "answers to the human's questions. USER: " - '\nq1 ASSISTANT:') - assert prompt == expected - - messages.append({'role': 'assistant', 'content': 'a1'}) - messages.append({'role': 'user', 'content': 'q2'}) - prompt = templtae.messages2prompt(messages) - expected = ( - 'A chat between a curious human and an artificial intelligence ' - 'assistant. The assistant gives helpful, detailed, and polite ' - "answers to the human's questions. 
USER: " - '\nq1 ASSISTANT: a1USER: q2 ASSISTANT:') - assert prompt == expected - - -def test_internvl2_conv(): - # https://huggingface.co/OpenGVLab/InternVL2-8B/blob/3bfd3664dea4f3da628785f5125d30f889701253/conversation.py - from transformers.dynamic_module_utils import get_class_from_dynamic_module - get_conv_template = get_class_from_dynamic_module( - 'conversation.get_conv_template', 'OpenGVLab/InternVL2-8B') - template = get_conv_template('internlm2-chat') - question1 = 'question1' - template.append_message(template.roles[0], question1) - template.append_message(template.roles[1], None) - model = MODELS.get('internvl2-internlm2')() - messages = [dict(role='user', content=question1)] - assert template.get_prompt() == model.messages2prompt(messages) - - answer1 = 'answer1' - template.messages[-1][1] = answer1 - question2 = 'question2' - template.append_message(template.roles[0], question2) - template.append_message(template.roles[1], None) - messages.append(dict(role='assistant', content=answer1)) - messages.append(dict(role='user', content=question2)) - assert template.get_prompt() == model.messages2prompt(messages) - - -def test_llava_conv_chatml_direct(): - model = MODELS.get('llava-chatml')() - templtae = VLChatTemplateWrapper(model) - messages = [ - dict(role='user', - content=[ - dict(type='text', text='q1'), - dict(type='image_url', image_url=dict(url='xxx')) - ]) - ] - - prompt = templtae.messages2prompt(messages) - expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>' - '<|im_start|>user\n\nq1<|im_end|>' - '<|im_start|>assistant\n') - assert prompt == expected - - messages.append({'role': 'assistant', 'content': 'a1'}) - messages.append({'role': 'user', 'content': 'q2'}) - prompt = templtae.messages2prompt(messages) - expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>' - '<|im_start|>user\n\nq1<|im_end|>' - '<|im_start|>assistant\na1<|im_end|>' - '<|im_start|>user\nq2<|im_end|>' - '<|im_start|>assistant\n') - assert prompt == expected - - -def test_custom_image_token(): - from lmdeploy.vl.templates import DeepSeekVLChatTemplateWrapper - model = MODELS.get('deepseek-vl')() - template = DeepSeekVLChatTemplateWrapper(model) - - def create_user(query: str): - item = dict(role='user', content=[dict(type='text', text=query)]) - num = query.count(IMAGE_TOKEN) - for _ in range(num): - item['content'].append( - dict(type='image_url', image_url=dict(url='xxx'))) - return item - - def create_assistant(response: str): - return dict(role='assistant', content=response) - - messages = [create_user(f'{IMAGE_TOKEN} q1')] - prompt = template.messages2prompt(messages) - expected = ('You are a helpful language and vision assistant. You are able' - ' to understand the visual content that the user provides, and' - ' assist the user with a variety of tasks using natural ' - 'language.\n\nUser: q1\n\nAssistant:') - assert prompt == expected - - messages.append(create_assistant('a1')) - messages.append(create_user(f'q2 {IMAGE_TOKEN}')) - prompt = template.messages2prompt(messages) - expected = ('You are a helpful language and vision assistant. You are able' - ' to understand the visual content that the user provides, and' - ' assist the user with a variety of tasks using natural ' - 'language.\n\nUser: q1\n\nAssistant: ' - 'a1<|end▁of▁sentence|>User: q2 \n\nAssistant:') - assert prompt == expected
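
For reference, the setup.py hunk earlier in this patch replaces the torch_npu probe with an explicit target-device switch: the build reads the LMDEPLOY_TARGET_DEVICE environment variable (defaulting to 'cuda') and resolves the per-device requirement files this diff introduces or renames (requirements/runtime_cuda.txt, runtime_ascend.txt, runtime_maca.txt, and the matching requirements_*.txt). A minimal sketch of that selection logic, assuming only the names visible in the diff:

import os

def get_target_device() -> str:
    # same default as the patched setup.py: fall back to CUDA when unset
    return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda')

device = get_target_device()
install_requires = f'requirements/runtime_{device}.txt'  # e.g. requirements/runtime_maca.txt
extras_all = f'requirements_{device}.txt'                # e.g. requirements_maca.txt
print(install_requires, extras_all)

In practice an Ascend or MACA build would set the variable before installing, e.g. LMDEPLOY_TARGET_DEVICE=ascend; the exact install command is not part of this diff.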