diff --git a/.github/workflows/pr_ete_test.yml b/.github/workflows/pr_ete_test.yml index 3a19ebe870..2d1c4b63f5 100644 --- a/.github/workflows/pr_ete_test.yml +++ b/.github/workflows/pr_ete_test.yml @@ -10,7 +10,7 @@ on: - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - - "requirements.txt" + - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" workflow_dispatch: @@ -68,7 +68,7 @@ jobs: export PATH=$PATH:/usr/local/openmpi/bin export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/openmpi/lib python3 -m pip install cmake packaging wheel transformers_stream_generator transformers datasets openai einops timm decord - python3 -m pip install -r requirements.txt -r requirements/test.txt -r requirements/build.txt + python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt -r requirements/build.txt mkdir -p build && cd build &&\ sh ../generate.sh &&\ ninja -j$(nproc) && ninja install &&\ diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index ec6db0682d..3a459050ec 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -10,7 +10,7 @@ on: - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - - "requirements.txt" + - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" push: @@ -24,7 +24,7 @@ on: - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - - "requirements.txt" + - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" tags: @@ -39,6 +39,7 @@ jobs: options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e CUDA_VISIBLE_DEVICES=2,3 --pull never" volumes: - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip + - /nvme/share_data/github-actions/hf_home:/root/.cache/huggingface - /nvme/share_data/github-actions/packages:/root/packages - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: @@ -78,7 +79,7 @@ jobs: python3 -m pip install pynvml packaging protobuf transformers_stream_generator # manually install flash attn python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp38-cp38-linux_x86_64.whl - python3 -m pip install -r requirements.txt -r requirements/test.txt + python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt python3 -m pip install . - name: Check env run: | diff --git a/README.md b/README.md index d160338aa6..8ef7b7994f 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,8 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • +
  • Qwen2-MoE (57BA14B)
  • +
  • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • @@ -136,6 +138,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • +
  • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
  • diff --git a/README_ja.md b/README_ja.md index fda176229e..77badaac36 100644 --- a/README_ja.md +++ b/README_ja.md @@ -122,6 +122,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • +
  • Qwen2-MoE (57BA14B)
  • +
  • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • @@ -133,6 +135,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • +
  • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
  • diff --git a/README_zh-CN.md b/README_zh-CN.md index 6c24b2e500..9f3cd40a64 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -126,6 +126,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • +
  • Qwen2-MoE (57BA14B)
  • +
  • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • @@ -137,6 +139,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • +
  • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
  • diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py index 24b4a3f8cd..dd8db1ccc4 100644 --- a/autotest/utils/config_utils.py +++ b/autotest/utils/config_utils.py @@ -97,7 +97,7 @@ def get_all_model_list(tp_num: int = None, model_type=model_type): if case not in case_list: case_list.append(case) - return [x for x in case_list if 'w8a8' not in x] + return case_list def get_quantization_model_list(type): diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index d41f886227..3a343e8f5b 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -377,15 +377,17 @@ def main(): requests = sample_requests(args.dataset, args.num_prompts, engine.tokenizer) - engine.process_request(requests, - temperature=args.temperature, - top_p=args.top_p, - top_k=args.top_k, - concurrency=args.concurrency, - stream_output=not args.no_stream_output, - skip_tokenize=args.skip_tokenize, - skip_detokenize=args.skip_detokenize, - cancel_rate=args.cancel_rate) + engine.process_request( + requests, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + concurrency=args.concurrency + if args.concurrency < args.num_prompts else args.num_prompts, + stream_output=not args.no_stream_output, + skip_tokenize=args.skip_tokenize, + skip_detokenize=args.skip_detokenize, + cancel_rate=args.cancel_rate) if __name__ == '__main__': diff --git a/docker/Dockerfile b/docker/Dockerfile index caa58ee637..24b2b055da 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,9 +10,6 @@ FROM ${CUDA_VERSION} AS final ARG PYTHON_VERSION=3.10 -ARG TORCH_VERSION=2.3.0 -ARG TORCHVISION_VERSION=0.18.0 - RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl openssh-server ssh sudo &&\ curl https://sh.rustup.rs -sSf | sh -s -- -y &&\ add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \ @@ -43,7 +40,6 @@ ENV LD_LIBRARY_PATH=/usr/local/nccl/lib:$LD_LIBRARY_PATH RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install --upgrade pip setuptools==69.5.1 &&\ - python3 -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\ python3 -m pip install cmake packaging wheel ENV NCCL_LAUNCH_MODE=GROUP @@ -54,7 +50,7 @@ COPY . /opt/lmdeploy WORKDIR /opt/lmdeploy RUN --mount=type=cache,target=/root/.cache/pip cd /opt/lmdeploy &&\ - python3 -m pip install -r requirements.txt &&\ + python3 -m pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\ mkdir -p build && cd build &&\ sh ../generate.sh &&\ ninja -j$(nproc) && ninja install &&\ diff --git a/docker/Dockerfile_aarch64_ascend b/docker/Dockerfile_aarch64_ascend index 1c9591197b..ecc2d1334e 100644 --- a/docker/Dockerfile_aarch64_ascend +++ b/docker/Dockerfile_aarch64_ascend @@ -122,4 +122,4 @@ WORKDIR /opt/lmdeploy RUN --mount=type=cache,target=/root/.cache/pip \ sed -i '/triton/d' requirements/runtime.txt && \ - pip3 install -v --no-build-isolation -e . + LMDEPLOY_TARGET_DEVICE=ascend pip3 install -v --no-build-isolation -e . diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md index 23b86afa61..d104477ca1 100644 --- a/docs/en/get_started/ascend/get_started.md +++ b/docs/en/get_started/ascend/get_started.md @@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu ``` Please check [supported_models](../../supported_models/supported_models.md) before use this feature. + +### int8 KV-cache Quantization + +Ascend backend has supported offline int8 KV-cache Quantization on eager mode. + +Please refer this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details. diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md index b3e8bb8abd..c00111c2ab 100644 --- a/docs/en/get_started/installation.md +++ b/docs/en/get_started/installation.md @@ -23,7 +23,7 @@ pip install lmdeploy The default prebuilt package is compiled on **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy by: ```shell -export LMDEPLOY_VERSION=0.6.3 +export LMDEPLOY_VERSION=0.6.4 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/docs/en/llm/api_server.md b/docs/en/llm/api_server.md index 285b0e32ff..274ec2ff25 100644 --- a/docs/en/llm/api_server.md +++ b/docs/en/llm/api_server.md @@ -249,6 +249,57 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \ lmdeploy serve gradio api_server_url --server-name ${gradio_ui_ip} --server-port ${gradio_ui_port} ``` +## Launch multiple api servers + +Following are two steps to launch multiple api servers through torchrun. Just create a python script with the following codes. + +1. Launch the proxy server through `lmdeploy serve proxy`. Get the correct proxy server url. +2. Launch the script through `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`.**Note**: Please do not use `0.0.0.0:8000` here, instead, we input the real ip name, `11.25.34.55:8000` for example. + +```python +import os +import socket +from typing import List, Literal + +import fire + + +def get_host_ip(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + return ip + + +def main(model_path: str, + tp: int = 1, + proxy_url: str = 'http://0.0.0.0:8000', + port: int = 23333, + backend: Literal['turbomind', 'pytorch'] = 'turbomind'): + local_rank = int(os.environ.get('LOCAL_RANK', -1)) + world_size = int(os.environ.get('WORLD_SIZE', -1)) + local_ip = get_host_ip() + if isinstance(port, List): + assert len(port) == world_size + port = port[local_rank] + else: + port += local_rank * 10 + if (world_size - local_rank) % tp == 0: + rank_list = ','.join([str(local_rank + i) for i in range(tp)]) + command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\ + f'--server-name {local_ip} --server-port {port} --tp {tp} '\ + f'--proxy-url {proxy_url} --backend {backend}' + print(f'running command: {command}') + os.system(command) + + +if __name__ == '__main__': + fire.Fire(main) +``` + ## FAQ 1. When user got `"finish_reason":"length"`, it means the session is too long to be continued. The session length can be diff --git a/docs/en/multi_modal/llava.md b/docs/en/multi_modal/llava.md index 8f052227d5..c374b67121 100644 --- a/docs/en/multi_modal/llava.md +++ b/docs/en/multi_modal/llava.md @@ -6,11 +6,17 @@ LMDeploy supports the following llava series of models, which are detailed in th | :----------------------------------: | :--: | :------------------------: | | llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch | +| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch | +| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch | +| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind | +| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind | The next chapter demonstrates how to deploy an Llava model using LMDeploy, with [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example. +```{note} +PyTorch engine removes the support of original llava models after v0.6.4. Please use their corresponding transformers models instead, which can be found in https://huggingface.co/llava-hf +``` + ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md). diff --git a/docs/en/multi_modal/qwen2_vl.md b/docs/en/multi_modal/qwen2_vl.md index 8b59f84545..fd9f02abaa 100644 --- a/docs/en/multi_modal/qwen2_vl.md +++ b/docs/en/multi_modal/qwen2_vl.md @@ -4,7 +4,7 @@ LMDeploy supports the following Qwen-VL series of models, which are detailed in | Model | Size | Supported Inference Engine | | :----------: | :----: | :------------------------: | -| Qwen-VL-Chat | - | TurboMind, Pytorch | +| Qwen-VL-Chat | - | TurboMind | | Qwen2-VL | 2B, 7B | PyTorch | The next chapter demonstrates how to deploy an Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example. diff --git a/docs/en/quantization/w4a16.md b/docs/en/quantization/w4a16.md index 0aa1e17a5b..c36c3736c6 100644 --- a/docs/en/quantization/w4a16.md +++ b/docs/en/quantization/w4a16.md @@ -128,3 +128,7 @@ We benchmarked the Llama-2-7B-chat and Llama-2-13B-chat models with 4-bit quanti | ---------------- | ------- | ------- | --------- | | Llama-2-7B-chat | 112.9 | 159.4 | 206.4 | | Llama-2-13B-chat | N/A | 90.7 | 115.8 | + +## FAQs + +1. Out of Memory error during quantization due to insufficient GPU memory: This can be addressed by reducing the parameter `--calib-seqlen`, increasing the parameter `--calib-samples`, and set `--batch-size` to 1. diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 469ece487f..dd8ceb4ffa 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -10,7 +10,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | +| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | @@ -18,9 +18,13 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | -| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | +| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | | Mistral | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | @@ -29,7 +33,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | | LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | | InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | -| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes | +| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes | | ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | @@ -41,7 +45,8 @@ The following tables detail the models supported by LMDeploy's TurboMind engine "-" means not verified yet. ```{note} -The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. +* The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. +* When the head_dim of a model is not 128, such as llama3.2-1B, qwen2-0.5B and internvl2-1B, turbomind doesn't support its kv cache 4/8 bit quantization and inference ``` ## PyTorchEngine on CUDA Platform @@ -68,11 +73,13 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha | QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | -| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | @@ -81,7 +88,7 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha | CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | | CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | | LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - | | Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - | | ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md index b137c458be..9f0a7b1f90 100644 --- a/docs/zh_cn/get_started/ascend/get_started.md +++ b/docs/zh_cn/get_started/ascend/get_started.md @@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu ``` 支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。 + +### int8 KV-cache 量化 + +昇腾后端现在支持了在eager模式下的离线int8 KV-cache量化。 + +详细使用方式请请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。 diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md index 12562c51d5..0213fa6d15 100644 --- a/docs/zh_cn/get_started/installation.md +++ b/docs/zh_cn/get_started/installation.md @@ -23,7 +23,7 @@ pip install lmdeploy 默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3),你可以使用以下命令安装 lmdeploy: ```shell -export LMDEPLOY_VERSION=0.6.3 +export LMDEPLOY_VERSION=0.6.4 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/docs/zh_cn/llm/api_server.md b/docs/zh_cn/llm/api_server.md index d6c0c42aef..8bb91c619e 100644 --- a/docs/zh_cn/llm/api_server.md +++ b/docs/zh_cn/llm/api_server.md @@ -258,6 +258,89 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \ }' ``` +## 同时启动多个 api_server + +两步直接启动多机多卡服务。先用下面的代码创建一个启动脚本。然后: + +1. 启动代理服务 `lmdeploy serve proxy`。 +2. torchrun 启动脚本 `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **注意**: 多机多卡不要用默认 url `0.0.0.0:8000`,我们需要输入真实ip对应的地址,如:`11.25.34.55:8000`。多机情况下,因为不需要子节点间的通信,所以并不需要用户指定 torchrun 的 `--nnodes` 等参数,只要能保证每个节点执行一次单节点的 torchrun 就行。 + +```python +import os +import socket +from typing import List, Literal + +import fire + + +def get_host_ip(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + return ip + + +def main(model_path: str, + tp: int = 1, + proxy_url: str = 'http://0.0.0.0:8000', + port: int = 23333, + backend: Literal['turbomind', 'pytorch'] = 'turbomind'): + local_rank = int(os.environ.get('LOCAL_RANK', -1)) + world_size = int(os.environ.get('WORLD_SIZE', -1)) + local_ip = get_host_ip() + if isinstance(port, List): + assert len(port) == world_size + port = port[local_rank] + else: + port += local_rank * 10 + if (world_size - local_rank) % tp == 0: + rank_list = ','.join([str(local_rank + i) for i in range(tp)]) + command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\ + f'--server-name {local_ip} --server-port {port} --tp {tp} '\ + f'--proxy-url {proxy_url} --backend {backend}' + print(f'running command: {command}') + os.system(command) + + +if __name__ == '__main__': + fire.Fire(main) +``` + +### 示例 + +为了进一步展示如何在集群环境中使用多机多卡服务。下面提供一个在火山云的用例: + +```shell +#!/bin/bash +# 激活 conda 环境 +source /path/to/your/home/miniconda3/bin/activate /path/to/your/home/miniconda3/envs/your_env +export HOME=/path/to/your/home +# 获取主节点IP地址(假设 MLP_WORKER_0_HOST 是主节点的IP) +MASTER_IP=${MLP_WORKER_0_HOST} +# 检查是否为主节点 +if [ "${MLP_ROLE_INDEX}" -eq 0 ]; then + # 启动 lmdeploy serve proxy 并放入后台 + echo "Starting lmdeploy serve proxy on master node..." + PROXY_PORT=8000 + lmdeploy serve proxy --server-name ${MASTER_IP} --server-port ${PROXY_PORT} & +else + # 这里我们默认调度平台同时启动了所有机器,否则要sleep一会,等待 proxy 启动成功 + echo "Not starting lmdeploy serve proxy on worker node ${MLP_ROLE_INDEX}." +fi +# 启动 torchrun 并放入后台 +# 再次强调多机环境下并不需要传--nnodes 或者 --master-addr 等参数,相当于每个机器上执行一次单节点的 torchrun 即可。 +torchrun \ +--nproc_per_node=${MLP_WORKER_GPU} \ +/path/to/script.py \ +InternLM/internlm2-chat-1_8b 8 http://${MASTER_IP}:${PROXY_PORT} +# 打印主机的IP地址 +echo "Host IP addresses:" +hostname -I +``` + ## 接入 WebUI LMDeploy 提供 gradio 和 [OpenAOE](https://github.com/InternLM/OpenAOE) 两种方式,为 api_server 接入 WebUI。 diff --git a/docs/zh_cn/multi_modal/llava.md b/docs/zh_cn/multi_modal/llava.md index c40f37308a..6538d1b861 100644 --- a/docs/zh_cn/multi_modal/llava.md +++ b/docs/zh_cn/multi_modal/llava.md @@ -6,11 +6,17 @@ LMDeploy 支持以下 LLaVA 系列模型,具体如下表所示: | :----------------------------------: | :--: | :----------------: | | llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch | -| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch | +| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch | +| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch | +| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind | +| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind | 接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型,并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。 +```{note} +自 0.6.4 之后,PyTorch 引擎移除了对 llava 原始模型的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到 +``` + ## 安装 请按照[安装指南](../get_started/installation.md)安装 LMDeploy。 diff --git a/docs/zh_cn/multi_modal/qwen2_vl.md b/docs/zh_cn/multi_modal/qwen2_vl.md index f62d2de74c..7cb7efe93b 100644 --- a/docs/zh_cn/multi_modal/qwen2_vl.md +++ b/docs/zh_cn/multi_modal/qwen2_vl.md @@ -4,7 +4,7 @@ LMDeploy 支持 Qwen-VL 系列模型,具体如下: | Model | Size | Supported Inference Engine | | :----------: | :----: | :------------------------: | -| Qwen-VL-Chat | - | TurboMind, Pytorch | +| Qwen-VL-Chat | - | TurboMind | | Qwen2-VL | 2B, 7B | PyTorch | 本文将以[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)为例,演示使用 LMDeploy 部署 Qwen2-VL 系列模型的方法 diff --git a/docs/zh_cn/quantization/w4a16.md b/docs/zh_cn/quantization/w4a16.md index d69a8a23d2..3cea164dd9 100644 --- a/docs/zh_cn/quantization/w4a16.md +++ b/docs/zh_cn/quantization/w4a16.md @@ -131,3 +131,8 @@ lmdeploy serve api_client http://0.0.0.0:23333 | ---------------- | ------- | ------- | --------- | | Llama-2-7B-chat | 112.9 | 159.4 | 206.4 | | Llama-2-13B-chat | N/A | 90.7 | 115.8 | + +## 快速问答 + +1. 量化时出现 Out of Memory 显存不够:可以通过减小传参 `--calib-seqlen`,增大传参 `--calib-samples`,并使用 `--batch-size` 为 1。 +2. 量化时,无法链接huggingface并下载数据集。可以尝试使用镜像,`export HF_ENDPOINT=https://hf-mirror.com`。 diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index d734523282..3ec3688e1b 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -10,7 +10,7 @@ | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | +| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | @@ -18,9 +18,13 @@ | InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | -| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | +| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | | Mistral | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | @@ -29,7 +33,7 @@ | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | | LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | | InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | -| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes | +| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes | | ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | @@ -41,7 +45,8 @@ “-” 表示还没有验证。 ```{note} -turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine +* turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine +* 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等 ``` ## PyTorchEngine CUDA 平台 @@ -68,11 +73,13 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | -| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | @@ -81,7 +88,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | | CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | | LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - | | Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - | | ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | @@ -94,7 +101,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | ```{note} -* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. +* 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16。 ``` ## PyTorchEngine 华为昇腾平台 diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index ce5cbd98ff..760a82b1c9 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -from typing import Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union from transformers import AutoConfig @@ -128,7 +128,8 @@ def check_vl_llm(config: dict) -> bool: return True elif arch == 'MultiModalityCausalLM' and 'language_config' in config: return True - elif arch == 'ChatGLMModel' and 'vision_config' in config: + elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration' + ] and 'vision_config' in config: return True elif arch in supported_archs: return True @@ -193,3 +194,22 @@ def get_model_arch(model_path: str): raise RuntimeError( f'Could not find model architecture from config: {_cfg}') return arch, cfg + + +def search_nested_config(config, key): + """Recursively searches for the value associated with the given key in a + nested configuration of a model.""" + if isinstance(config, Dict): + for k, v in config.items(): + if k == key: + return v + if isinstance(v, (Dict, List)): + result = search_nested_config(v, key) + if result is not None: + return result + elif isinstance(config, List): + for item in config: + result = search_nested_config(item, key) + if result is not None: + return result + return None diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py index d76d6a5f34..236e022b34 100644 --- a/lmdeploy/cli/lite.py +++ b/lmdeploy/cli/lite.py @@ -35,6 +35,7 @@ def add_parser_auto_awq(): ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) + ArgumentHelper.dtype(parser) parser.add_argument( '--device', type=str, @@ -71,6 +72,7 @@ def add_parser_auto_gptq(): ArgumentHelper.calib_samples(parser) ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) + ArgumentHelper.dtype(parser) parser.add_argument('--w-bits', type=int, default=4, @@ -99,6 +101,7 @@ def add_parser_calibrate(): ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) + ArgumentHelper.dtype(parser) @staticmethod def add_parser_smooth_quant(): @@ -122,6 +125,7 @@ def add_parser_smooth_quant(): ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) + ArgumentHelper.dtype(parser) @staticmethod def auto_awq(args): diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 68f9de8c15..5cf3453b7e 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -238,6 +238,7 @@ def add_parser_proxy(): help='the strategy to dispatch requests to nodes') ArgumentHelper.api_keys(parser) ArgumentHelper.ssl(parser) + ArgumentHelper.log_level(parser) @staticmethod def gradio(args): diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 85784a58f5..cf7b6526ec 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -354,7 +354,7 @@ def calib_batchsize(parser): @staticmethod def calib_search_scale(parser): - """Add argument batch_size to parser.""" + """Add argument search_scale to parser.""" return parser.add_argument( '--search-scale', diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py index c41b28fd6e..2c84612839 100644 --- a/lmdeploy/lite/apis/auto_awq.py +++ b/lmdeploy/lite/apis/auto_awq.py @@ -2,6 +2,7 @@ import os import os.path as osp import shutil +from typing import Literal import torch from torch import nn @@ -12,9 +13,7 @@ from lmdeploy.lite.utils import collect_target_modules from lmdeploy.pytorch.check_env import try_import_deeplink -from .calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate - -NORM_TYPE_MAP = NORM_TYPE_MAP # legacy +from .calibrate import LAYER_TYPE_MAP, calibrate def save_vl_model(vl_model, model_path, dst_path): @@ -56,6 +55,7 @@ def auto_awq(model: str, search_scale: bool = False, device: str = 'cuda', revision: str = None, + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', download_dir: str = None): """Perform weight quantization using AWQ algorithm. @@ -77,6 +77,7 @@ def auto_awq(model: str, revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + dtype (str): Data type for loading model weights and calib infer. download_dir (str): Directory to download and load the weights, default to the default cache directory of huggingface. """ @@ -96,6 +97,7 @@ def auto_awq(model: str, w_bits=w_bits, w_group_size=w_group_size, search_scale=search_scale, + dtype=dtype, batch_size=batch_size) layer_type = LAYER_TYPE_MAP[type(model).__name__] diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py index 71f7a5900c..007f831a70 100644 --- a/lmdeploy/lite/apis/calibrate.py +++ b/lmdeploy/lite/apis/calibrate.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from pathlib import Path -from typing import Union +from typing import Literal, Union import torch from torch import nn @@ -11,6 +11,7 @@ from lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2 from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders, load_hf_from_pretrained) +from lmdeploy.vl.model.builder import load_vl_model LAYER_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMDecoderLayer', @@ -204,6 +205,7 @@ def calibrate(model: str, w_bits: int = 4, w_group_size: int = 128, search_scale: bool = False, + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', batch_size: int = 1) -> None: """The main function for loading the model and performing calibration on a given dataset. @@ -224,6 +226,7 @@ def calibrate(model: str, w_group_size (int): Group size for weight quantization statistics. search_scale (bool): Whether search scale ratio. Default to False, which means only smooth quant with 0.5 ratio will be applied. + dtype (str): Data type for loading model weights and calib infer. batch_size (int): The batch size for running the calib samples. Low GPU mem requires small batch_size. Large batch_size reduces the calibration time while costs more VRAM. @@ -239,20 +242,35 @@ def calibrate(model: str, model_type, _ = get_task(model) make_compatible_internvl_config(model) - if model_type == 'llm': - # Load tokenizer and configuration - tokenizer = AutoTokenizer.from_pretrained(model, - trust_remote_code=True) + # Load tokenizer and configuration + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + + if model_type == 'llm': model = load_hf_from_pretrained(model, - torch_dtype=torch.float16, + dtype=dtype, trust_remote_code=True) vl_model = None elif model_type == 'vlm': - from lmdeploy.vl.model.builder import vl_model_with_tokenizer - vl_model, model, tokenizer = vl_model_with_tokenizer(model_path=model) + vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model + model = vl_model + if hasattr(vl_model, 'language_model'): # deepseek-vl, ... + model = vl_model.language_model + if hasattr(vl_model, 'llm'): # MiniCPMV, ... + model = vl_model.llm + model.config.use_cache = False + if dtype == 'float16': + model.half() + elif dtype == 'bfloat16': + assert torch.cuda.is_bf16_supported( + ), 'your device does not support bfloat16 please set --dtype float16' # noqa + model.to(torch.bfloat16) + elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16: + print('Warning: we cast model to float16 to prevent OOM. You' + ' may enforce it bfloat16 by `--dtype bfloat16`') + model.half() + model.eval() - model.config.use_cache = False model_type = type(model).__name__ if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP: raise RuntimeError( diff --git a/lmdeploy/lite/apis/gptq.py b/lmdeploy/lite/apis/gptq.py index 12b88a52cd..eb4418a533 100644 --- a/lmdeploy/lite/apis/gptq.py +++ b/lmdeploy/lite/apis/gptq.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging +from typing import Literal import torch -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from lmdeploy.lite.utils.calib_dataloader import get_calib_loaders @@ -15,6 +16,7 @@ def auto_gptq(model: str, calib_samples: int = 128, calib_seqlen: int = 2048, batch_size: int = 1, + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', revision: str = None): """Perform weight quantization using AWQ algorithm. @@ -29,9 +31,7 @@ def auto_gptq(model: str, calib_seqlen (int): The sequence length for calibration. w_bits (int): Bit number for weight quantization. w_group_size (int): Group size for weight quantization statistics. - search_scale (bool): Whether search scale ratio. Default to False, - which means only smooth quant with 0.5 ratio will be applied. - device (str): Device type of running. + dtype (str): Data type for loading model weights and calib infer. revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -83,9 +83,18 @@ def auto_gptq(model: str, # load un-quantized model, by default, # the model will always be loaded into CPU memory + hf_config = AutoConfig.from_pretrained(pretrained_model_dir, + revision=revision, + trust_remote_code=True) + torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16) + if dtype == 'float16': + torch_dtype = torch.float16 + elif dtype == 'bfloat16': + torch_dtype = torch.bfloat16 model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, revision=revision, + torch_dtype=torch_dtype, trust_remote_code=True) # quantize model, the examples should be list of dict whose keys diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py index c8df67355e..188eedbd0e 100644 --- a/lmdeploy/lite/apis/smooth_quant.py +++ b/lmdeploy/lite/apis/smooth_quant.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. + +from typing import Literal + import fire import torch from torch import nn @@ -6,7 +9,8 @@ from lmdeploy.lite.apis.calibrate import (LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate) from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP, - awq_layers, smooth_layers) + awq_layers, skipped_module, + smooth_layers) from lmdeploy.lite.utils import collect_target_modules from lmdeploy.pytorch.models import QLinear, QRMSNorm @@ -19,8 +23,8 @@ def smooth_quant(model: str, search_scale: bool = False, batch_size: int = 1, w_bits: int = 8, + dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', device: str = 'cuda'): - model_path = model vl_model, model, tokenizer, work_dir = calibrate(model, calib_dataset, @@ -31,6 +35,7 @@ def smooth_quant(model: str, w_bits=w_bits, w_group_size=-1, search_scale=search_scale, + dtype=dtype, batch_size=batch_size) # calibrate function exports the calibration statistics @@ -76,6 +81,8 @@ def smooth_quant(model: str, rmsnorms = collect_target_modules(model, norm_type) for name, linear in fcs.items(): + if skipped_module(name): + continue linear.to(device) q_linear = QLinear.from_float(linear) parent_name, _, child_name = name.rpartition('.') @@ -84,6 +91,8 @@ def smooth_quant(model: str, linear.to('cpu') for name, norm in rmsnorms.items(): + if skipped_module(name): + continue norm.to(device) q_norm = QRMSNorm.from_float(norm) parent_name, _, child_name = name.rpartition('.') diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py index cf03a75216..3e24a13cc3 100644 --- a/lmdeploy/lite/quantization/awq.py +++ b/lmdeploy/lite/quantization/awq.py @@ -43,8 +43,10 @@ 'MixtralDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], - 'post_attention_layernorm': - ['block_sparse_moe.experts.{i}.w1', 'block_sparse_moe.experts.{i}.w3'] + 'post_attention_layernorm': [ + 'block_sparse_moe.gate', 'block_sparse_moe.experts.{i}.w1', + 'block_sparse_moe.experts.{i}.w3' + ] }, 'Qwen2VLDecoderLayer': { 'input_layernorm': @@ -120,7 +122,12 @@ def get_weight_scale(weight, q_group_size=-1): org_shape = weight.shape if q_group_size > 0: weight = weight.view(-1, q_group_size) - scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + abs_weight = weight.abs() + abs_weight_amax = abs_weight.amax(dim=1, keepdim=True) + if abs_weight_amax.min().item() == 0: + print('weight.amax.min is zero, clamping weight.amax to 1e-4') + abs_weight_amax = abs_weight_amax.clamp(min=1e-4) + scale = abs_weight / abs_weight_amax scale = scale.view(org_shape) scale = scale.mean(0) return scale @@ -153,8 +160,13 @@ def smooth_ln_fcs(ln: torch.nn.Module, concat_w = torch.cat([fc.weight for fc in fcs], dim=0) w_scales = get_weight_scale(concat_w, group_size) + w_scales_pow = w_scales.pow(1 - alpha) + if w_scales_pow.min().item() == 0: + print('w_scales.pow(1 - alpha).min is zero, ' + 'clamping w_scales.pow(1 - alpha) to 1e-4') + w_scales_pow = w_scales_pow.clamp(min=1e-4) scales = (act_scales.pow(alpha) / - w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype) + w_scales_pow).clamp(min=1e-4).to(device).to(dtype) scales = scales / (scales[nonzero_positions].max() * scales[nonzero_positions].min()).sqrt() @@ -204,8 +216,13 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module, concat_w = torch.cat([fc.weight for fc in fcs], dim=0) w_scales = get_weight_scale(concat_w, group_size) + w_scales_pow = w_scales.pow(1 - alpha) + if w_scales_pow.min().item() == 0: + print('w_scales.pow(1 - alpha).min is zero, ' + 'clamping w_scales.pow(1 - alpha) to 1e-4') + w_scales_pow = w_scales_pow.clamp(min=1e-4) scales = (act_scales.pow(alpha) / - w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype) + w_scales_pow).clamp(min=1e-4).to(device).to(dtype) scales = scales / (scales.max() * scales.min()).sqrt() # (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale diff --git a/lmdeploy/lite/quantization/calibration.py b/lmdeploy/lite/quantization/calibration.py index e590f1a4eb..1df8f2c740 100644 --- a/lmdeploy/lite/quantization/calibration.py +++ b/lmdeploy/lite/quantization/calibration.py @@ -42,6 +42,9 @@ def __init__(self, tokenizer (PreTrainedTokenizer): Tokenizer of the given model. layer_type (Union[str, type]): Type of the layers to be observed. norm_type (Union[str, type]): Norm type used in the model. + batch_size (int): The batch size for running the calib samples. + Low GPU mem requires small batch_size. Large batch_size + reduces the calibration time while costs more VRAM. device (str, optional): Device where the model should run. Defaults to 'cuda'. """ @@ -290,9 +293,14 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}): org_sd = {k: v.cpu() for k, v in block.state_dict().items()} for ratio in range(0, n_grid): - ratio = ratio * 1 / n_grid - scales = (x_max.pow(ratio) / - w_mean.pow(1 - ratio)).clamp(min=1e-4).view(-1) + ratio = ratio / n_grid + w_mean_pow = w_mean.pow(1 - ratio) + if w_mean_pow.min().item() == 0: + print('w_mean.pow(1 - ratio).min is zero, ' + 'clamping w_mean.pow(1 - ratio) to 1e-4') + w_mean_pow = w_mean_pow.clamp(min=1e-4) + scales = (x_max.pow(ratio) / w_mean_pow).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() for fc in linears2scale: fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py index bfd306a743..ac4519371a 100644 --- a/lmdeploy/lite/utils/load.py +++ b/lmdeploy/lite/utils/load.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Literal + import torch from transformers import AutoConfig, AutoModelForCausalLM @@ -7,29 +9,42 @@ def load_hf_from_pretrained(pretrained_model_name_or_path, - dtype=torch.float16, - **kwargs): + dtype: Literal['float16', 'bfloat16', + 'auto'], **kwargs): - if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported(): raise RuntimeError('Your device does not supports bf16(bfloat16), ' 'please change to fp16(float16)') kwargs.pop('config', None) hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, - torch_dtype=dtype, trust_remote_code=True) # HACK hard code for qwen, other configs do not have the `fp16` attribute. - if dtype == torch.float16: - hf_config.fp16 = True - elif dtype == torch.bfloat16: - hf_config.bf16 = True + if hasattr(hf_config, 'fp16') or hasattr(hf_config, 'bf16'): + if dtype == 'bfloat16': + hf_config.bf16 = True + else: + hf_config.fp16 = True + + torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16) + if dtype == 'bfloat16': + torch_dtype = torch.bfloat16 + elif dtype == 'float16': + torch_dtype = torch.float16 + elif dtype == 'auto' and torch_dtype == torch.bfloat16: + print('Warning: we cast model to float16 to prevent OOM. ' + 'You may enforce it bfloat16 by `--dtype bfloat16`') + torch_dtype = torch.float16 with LoadNoInit(): # Load model model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path, config=hf_config, **kwargs) + pretrained_model_name_or_path, + config=hf_config, + torch_dtype=torch_dtype, + **kwargs) model.config.use_cache = False return model diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 4f04906f12..b54aa95330 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -293,8 +293,11 @@ def __post_init__(self): assert self.device_type in [ 'cuda', 'ascend', 'maca' ], (f'invalid device_type: {self.device_type}') - if self.quant_policy > 0 and self.device_type != 'cuda': - assert False, 'kv cache quantization only works for CUDA.' + if self.quant_policy > 0 and self.device_type not in [ + 'cuda', 'ascend' + ]: + assert False, \ + 'kv cache quantization only works for CUDA and ASCEND.' class ResponseType(enum.Enum): diff --git a/lmdeploy/model.py b/lmdeploy/model.py index a4355ea131..a0b0c8e09b 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -847,7 +847,7 @@ def __init__( - Only call one function at a time - Put the entire function call reply on one line" - Always add your sources when using search results to answer the user query\n\n""", # noqa - knowledge='Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n', + knowledge='Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n', meta_instruction='You are a helpful assistant.', ipython='<|start_header_id|>ipython<|end_header_id|>\n\n', eoi='<|eot_id|>', @@ -1921,5 +1921,5 @@ def best_match_model(query: str) -> Optional[str]: for name, model in MODELS.module_dict.items(): if model.match(query): return model.match(query) - logger.warn(f'Did not find a chat template matching {query}.') + logger.warning(f'Did not find a chat template matching {query}.') return 'base' diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 92a0befbf4..f0e60d86ac 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -34,6 +34,7 @@ def __init__( alibi: bool = None, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ) -> None: if scale is None: @@ -53,6 +54,7 @@ def __init__( self.alibi = alibi self.sliding_window = sliding_window self.logit_softcapping = logit_softcapping + self.causal = causal @abstractmethod def forward( @@ -82,6 +84,7 @@ def build( alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> AttentionImpl[T]: """build.""" diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index ef538f7a3d..c8623666dc 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -12,7 +12,8 @@ class OpType(Enum): """Layer type enumerate.""" - Attention = auto() + PagedAttention = auto() + FlashAttention = auto() Linear = auto() RotaryEmbedding = auto() ApplyRotaryEmb = auto() diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 8261b869f0..f9227497f2 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -41,6 +41,7 @@ def __init__( alibi: bool = False, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ): super().__init__( @@ -52,8 +53,10 @@ def __init__( alibi=alibi, sliding_window=sliding_window, logit_softcapping=logit_softcapping, + causal=causal, **kwargs, ) + assert not (alibi and not causal) from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd, fill_kv_cache, @@ -172,6 +175,7 @@ def forward( window_size=self.sliding_window, sm_scale=self.scale, logit_softcapping=self.logit_softcapping, + causal=self.causal, ) else: self.alibi_paged_attention_fwd( @@ -207,6 +211,7 @@ def build( alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> TritonAttentionImpl: """build.""" @@ -218,4 +223,5 @@ def build( alibi=alibi, sliding_window=sliding_window, logical_softcapping=logical_softcapping, + causal=causal, **kwargs) diff --git a/lmdeploy/pytorch/backends/cuda/flash_attention.py b/lmdeploy/pytorch/backends/cuda/flash_attention.py new file mode 100644 index 0000000000..5d3925b744 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/flash_attention.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor + +from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl + + +class TritonFlashAttentionImpl(FlashAttentionImpl): + """triton flash attention implementation.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + ): + if scale is None: + scale = 1.0 / (head_dim**0.5) + + if num_kv_heads is None: + num_kv_heads = num_heads + + if v_head_dim is None: + v_head_dim = head_dim + + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.v_head_dim = v_head_dim + self.causal = causal + self.sliding_window = sliding_window + self.logical_softcapping = logical_softcapping + + from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd + self.flash_attention_fwd = flash_attention_fwd + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_dim, ) + out = query.new_empty(o_shape) + self.flash_attention_fwd( + query, + key, + value, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logical_softcapping, + causal=self.causal, + kv_layout='shd', + ) + + return out + + +class TritonFlashAttentionBuilder(FlashAttentionBuilder): + """triton attention builder.""" + + @staticmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + return TritonFlashAttentionImpl( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logical_softcapping=logical_softcapping, + ) diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index d796f8e19f..bfe89dc63d 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -23,9 +23,12 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: OpType): """get cuda layer builder.""" - if layer_type == OpType.Attention: + if layer_type == OpType.PagedAttention: from .attention import TritonAttentionBuilder return TritonAttentionBuilder + elif layer_type == OpType.FlashAttention: + from .flash_attention import TritonFlashAttentionBuilder + return TritonFlashAttentionBuilder elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import TritonApplyRotaryEmbBuilder return TritonApplyRotaryEmbBuilder @@ -125,30 +128,30 @@ def update_step_context(cls, step_context): quant_policy=step_context.kv_quant_policy, ) - cross_attn_metadata = None - fill_seqlens = None - if step_context.cross_attention_states is not None: - fill_seqlens = torch.zeros_like(q_seqlens) - for idx, state in enumerate(step_context.cross_attention_states): - if state is not None: - fill_seqlens[idx] = state.shape[-2] + cross_seqlens = step_context.cross_seqlens cross_kv_seqlens = step_context.cross_kv_seqlens - cross_kv_start_loc = None - cross_kv_flatten_size = None - if not step_context.is_decoding and cross_kv_seqlens is not None: - cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens - cross_kv_flatten_size = cross_kv_seqlens.sum().item() - cross_attn_metadata = attn_meta_cls( - step_context.is_decoding, - step_context.block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seqlens, - kv_start_loc=cross_kv_start_loc, - kv_seqlens=cross_kv_seqlens, - kv_flatten_size=cross_kv_flatten_size, - fill_seqlens=fill_seqlens, - quant_policy=step_context.kv_quant_policy, - ) + cross_attn_metadata = None + if cross_seqlens is not None: + fill_seqlens = cross_seqlens + if fill_seqlens.sum().item() == 0: + fill_seqlens = None + cross_kv_start_loc = None + cross_kv_flatten_size = None + if not step_context.is_decoding and cross_kv_seqlens is not None: + cross_kv_start_loc = cross_kv_seqlens.cumsum( + 0) - cross_kv_seqlens + cross_kv_flatten_size = cross_kv_seqlens.sum().item() + cross_attn_metadata = attn_meta_cls( + step_context.is_decoding, + step_context.block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=cross_kv_start_loc, + kv_seqlens=cross_kv_seqlens, + kv_flatten_size=cross_kv_flatten_size, + fill_seqlens=fill_seqlens, + quant_policy=step_context.kv_quant_policy, + ) step_context.attn_metadata = attn_metadata step_context.cross_attn_metadata = cross_attn_metadata diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py index f9664f13ff..e3c5dc4d5e 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py @@ -33,10 +33,17 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, dlinfer.graph.config.enable_graph_mode = True self.patch_kernels_custom_op() self.patch_kvcache_static_shape() - self.model = torch.compile(self.model, - fullgraph=True, - dynamic=True, - backend='atbgraph') + if hasattr(self.model, 'language_model'): + self.model.language_model = torch.compile( + self.model.language_model, + fullgraph=True, + dynamic=True, + backend='atbgraph') + else: + self.model = torch.compile(self.model, + fullgraph=True, + dynamic=True, + backend='atbgraph') def check_enable_graph(self): """check enable graph.""" diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index b6f544510b..588558f0d5 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple +import itertools +import os +import re +from pathlib import Path +from typing import Dict, Tuple import torch @@ -11,6 +15,71 @@ logger = get_logger('lmdeploy') +class AscendKVQuantMeta: + has_set_value: bool = False + quant_meta: Dict = {} + + @classmethod + def set_value(cls, device: str, dtype: torch.dtype, record_file: str, + total_layers: int): + with open(record_file, 'r') as file: + data = file.read() + scale_offset_pairs = re.findall( + r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data) + scale_offset_pairs = [(float(scale), float(offset)) + for scale, offset in scale_offset_pairs] + k_scales, v_scales, kv_scales = [], [], [] + k_zeros, v_zeros, kv_zeros = [], [], [] + if len(scale_offset_pairs) == total_layers: + for scale, offset in scale_offset_pairs: + k_scales.append( + torch.tensor([scale], device=device, dtype=dtype)) + v_scales.append( + torch.tensor([scale], device=device, dtype=dtype)) + kv_scales.append( + torch.tensor([scale, scale], device=device, dtype=dtype)) + k_zeros.append( + torch.tensor([offset], device=device, dtype=dtype)) + v_zeros.append( + torch.tensor([offset], device=device, dtype=dtype)) + kv_zeros.append( + torch.tensor([offset, offset], device=device, dtype=dtype)) + elif len(scale_offset_pairs) == total_layers * 2: + for i in range(total_layers): + scale_k, offset_k = scale_offset_pairs[2 * i] + scale_v, offset_v = scale_offset_pairs[2 * i + 1] + k_scales.append( + torch.tensor([scale_k], device=device, dtype=dtype)) + v_scales.append( + torch.tensor([scale_v], device=device, dtype=dtype)) + kv_scales.append( + torch.tensor([scale_k, scale_v], + device=device, + dtype=dtype)) + k_zeros.append( + torch.tensor([offset_k], device=device, dtype=dtype)) + v_zeros.append( + torch.tensor([offset_v], device=device, dtype=dtype)) + kv_zeros.append( + torch.tensor([offset_k, offset_v], + device=device, + dtype=dtype)) + else: + raise ValueError( + f'num of scale_offset_pairs({len(scale_offset_pairs)}) ' + f'must match num of total_layers({total_layers})') + + cls.quant_meta.update({ + 'k_scales': itertools.cycle(k_scales), + 'k_zeros': itertools.cycle(k_zeros), + 'v_scales': itertools.cycle(v_scales), + 'v_zeros': itertools.cycle(v_zeros), + 'kv_scales': itertools.cycle(kv_scales), + 'kv_zeros': itertools.cycle(kv_zeros) + }) + cls.has_set_value = True + + class AscendOpsBackend(DlinferOpsBackend): """ascend layer backend.""" enable_graph = False @@ -164,6 +233,21 @@ def get_total_slots(): .repeat_interleave(step_context.q_seqlens, 0) kv_seqlens = kv_seqlens_cpu + if not cls.enable_graph and step_context.kv_quant_policy == 8: + record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') + assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE' + path = Path(record_file) + is_path = path.is_absolute() or path.is_relative_to('/') + exists = path.exists() + if not (is_path and exists): + raise ValueError( + 'please specify valid ASCEND_QUANT_RECORD_FILE') + if not AscendKVQuantMeta.has_set_value: + total_layers = len(step_context.kv_caches) + AscendKVQuantMeta.set_value(step_context.block_offsets.device, + step_context.model_config.dtype, + record_file, total_layers) + attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( step_context.is_decoding, @@ -177,6 +261,8 @@ def get_total_slots(): is_unpaged_prefill=is_unpaged_prefill, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, + quant_policy=step_context.kv_quant_policy, + quant_meta=AscendKVQuantMeta.quant_meta, ) step_context.attn_metadata = attn_metadata diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 0d666c9130..6b03403c84 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from dataclasses import dataclass -from typing import Optional, Sequence +from typing import Dict, Optional, Sequence from torch import Tensor @@ -15,6 +15,7 @@ class DlinferAttentionMetadata(AttentionMetadata): is_unpaged_prefill: Optional[bool] = None max_q_seq_len: int = 1 max_kv_seq_len: int = 1 + quant_meta: Dict = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): @@ -30,8 +31,10 @@ def __init__( alibi: bool = None, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ): + assert causal super().__init__( num_heads, head_size, @@ -41,6 +44,7 @@ def __init__( alibi, sliding_window, logit_softcapping, + causal=causal, **kwargs, ) @@ -74,10 +78,37 @@ def forward( is_unpaged_prefill = attn_metadata.is_unpaged_prefill max_q_seq_len = attn_metadata.max_q_seq_len max_kv_seq_len = attn_metadata.max_kv_seq_len + quant_bits = attn_metadata.quant_policy + if attn_metadata.quant_meta is not None: + k_scales_zeros = [ + next(attn_metadata.quant_meta['k_scales']), + next(attn_metadata.quant_meta['k_zeros']) + ] if 'k_scales' in attn_metadata.quant_meta else [] + v_scales_zeros = [ + next(attn_metadata.quant_meta['v_scales']), + next(attn_metadata.quant_meta['v_zeros']) + ] if 'v_scales' in attn_metadata.quant_meta else [] + kv_scales = next( + attn_metadata.quant_meta['kv_scales'] + ) if 'kv_scales' in attn_metadata.quant_meta else None + kv_zeros = next( + attn_metadata.quant_meta['kv_zeros'] + ) if 'kv_zeros' in attn_metadata.quant_meta else None + else: + k_scales_zeros = [] + v_scales_zeros = [] + kv_scales = None + kv_zeros = None # fill kv cache - k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, - kv_start_indices) + k_cache, v_cache = self.fill_kv_cache(key, + value, + k_cache, + v_cache, + kv_start_indices, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros, + quant_bits=quant_bits) if inplace: attn_output = query[..., :self.v_head_size] @@ -103,6 +134,9 @@ def forward( block_size=block_size, attn_mask=attn_mask, is_unpaged_prefill=is_unpaged_prefill, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) return attn_output @@ -121,6 +155,7 @@ def build( alibi_scale: float = None, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> DlinferAttentionImpl: """build.""" @@ -132,4 +167,5 @@ def build( alibi_scale=alibi_scale, sliding_window=sliding_window, logical_softcapping=logical_softcapping, + causal=causal, **kwargs) diff --git a/lmdeploy/pytorch/backends/dlinfer/flash_attention.py b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py new file mode 100644 index 0000000000..d0d9ddbb26 --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor + +from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl + + +class DlinferFlashAttentionImpl(FlashAttentionImpl): + """dlinfer flash attention implementation.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + ): + if scale is None: + scale = 1.0 / (head_dim**0.5) + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_dim is None: + v_head_dim = head_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.v_head_dim = v_head_dim + self.causal = causal + self.sliding_window = sliding_window + self.logical_softcapping = logical_softcapping + from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd + self.flash_attention_fwd = flash_attention_fwd + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_dim, ) + out = query.new_empty(o_shape) + self.flash_attention_fwd( + query, + key, + value, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_q_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logical_softcapping, + causal=self.causal, + ) + return out + + +class DlinferFlashAttentionBuilder(FlashAttentionBuilder): + """dlinfer attention builder.""" + + @staticmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + return DlinferFlashAttentionImpl( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logical_softcapping=logical_softcapping, + ) diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py index 6ada730fbe..ff986c5765 100644 --- a/lmdeploy/pytorch/backends/dlinfer/moe.py +++ b/lmdeploy/pytorch/backends/dlinfer/moe.py @@ -47,8 +47,8 @@ def forward(self, down_weights: torch.Tensor, expert_list: List[int] = None): """forward.""" - return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights, - gate_up_weights, down_weights) + return fused_moe(hidden_states, gate_up_weights, down_weights, + topk_weights, topk_ids, self.top_k, self.renormalize) class DlinferFusedMoEBuilder(FusedMoEBuilder): diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py index 52a8830595..a0f04f34b1 100644 --- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py @@ -22,9 +22,12 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: OpType): """get dlinfer layer builder.""" - if layer_type == OpType.Attention: + if layer_type == OpType.PagedAttention: from .attention import DlinferAttentionBuilder return DlinferAttentionBuilder + elif layer_type == OpType.FlashAttention: + from .flash_attention import DlinferFlashAttentionBuilder + return DlinferFlashAttentionBuilder elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder return DlinferApplyRotaryEmbBuilder diff --git a/lmdeploy/pytorch/backends/flash_attention.py b/lmdeploy/pytorch/backends/flash_attention.py new file mode 100644 index 0000000000..bed3af8d68 --- /dev/null +++ b/lmdeploy/pytorch/backends/flash_attention.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +from torch import Tensor + + +class FlashAttentionImpl(ABC): + """FlashAttention implementation.""" + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + raise NotImplementedError + + +class FlashAttentionBuilder(ABC): + """FlashAttention implementation builder.""" + + @staticmethod + @abstractmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 9ab66b26a2..9347995e0b 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -46,3 +46,26 @@ def prepare_inputs_for_generation( inputs_embeds, context, ) + + def update_model_metas( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare inputs.""" + if hasattr(self.model, 'update_model_metas'): + return self.model.update_model_metas( + past_key_values, + inputs_embeds, + context, + ) + + return None + + def get_input_processor(self): + """get input processor.""" + if hasattr(self.model, 'get_input_processor'): + return self.model.get_input_processor() + else: + return None diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 7d72438224..bc95a32be6 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -1,277 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from logging import Logger -from typing import List - -from lmdeploy.utils import get_logger - - -def _handle_exception(e: Exception, - mod_name: str, - logger: Logger, - message: str = None): - red_color = '\033[31m' - reset_color = '\033[0m' - if message is None: - message = 'Please ensure it has been installed correctly.' - logger.debug('Exception', exc_info=1) - logger.error(f'{type(e).__name__}: {e}') - logger.error(f'{red_color}' - f'<{mod_name}> test failed!\n' - f'{message}' - f'{reset_color}') - exit(1) +from .base import BaseChecker # noqa: F401 def check_env_deeplink(device_type: str): """check Deeplink environment.""" - try_import_deeplink(device_type) + from .deeplink import DeeplinkChecker + checker = DeeplinkChecker(device_type) + checker.handle() def try_import_deeplink(device_type: str): - """import dlinfer if specific device_type is set.""" - deeplink_device_type_list = [ - 'ascend', - 'npu', - 'maca', - ] - if device_type in deeplink_device_type_list: - logger = get_logger('lmdeploy') - try: - import dlinfer.framework.lmdeploy_ext # noqa: F401 - except Exception as e: - _handle_exception(e, 'PyTorch', logger) - - -def check_env_torch(): - """check PyTorch environment.""" - logger = get_logger('lmdeploy') - - try: - logger.debug('Checking environment.') - import torch - - a = torch.tensor([1, 2], device='cuda') - b = a.new_tensor([3, 4], device='cuda') - c = a + b - torch.testing.assert_close(c, a.new_tensor([4, 6])) - except Exception as e: - _handle_exception(e, 'PyTorch', logger) - - -MAX_TRITON_VERSION = '3.0.0' - - -def check_env_triton(device: str): - """check OpenAI Triton environment.""" - from packaging import version - logger = get_logger('lmdeploy') - - msg = ( - 'Please ensure that your device is functioning properly with .\n' # noqa: E501 - 'You can verify your environment by running ' - '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.') - try: - logger.debug('Checking environment.') - import torch - import triton - triton_version = version.parse(triton.__version__) - if triton_version > version.parse(MAX_TRITON_VERSION): - logger.warning( - f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.') - - from .triton_custom_add import custom_add - a = torch.tensor([1, 2], device='cuda') - b = a.new_tensor([3, 4], device='cuda') - c = custom_add(a, b) - torch.testing.assert_close(c, a + b) - except RuntimeError as e: - ptxas_error = 'device kernel image is invalid' - if len(e.args) > 0 and ptxas_error in e.args[0]: - msg = ( - 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501 - 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501 - ' or reinstall the driver.') - _handle_exception(e, 'Triton', logger, msg) - except Exception as e: - _handle_exception(e, 'Triton', logger, msg) - - if device == 'cuda': - device_cap = torch.cuda.get_device_capability() - TRITON_VER_231 = version.parse('2.3.1') - - if device_cap[0] <= 7: - if triton_version <= TRITON_VER_231: - err = RuntimeError( - 'Attention triton kernel does not fully support ' - 'triton<3.0.0 on device with capability<8. ' - 'Please upgrade your triton version.') - _handle_exception(err, 'Triton', logger) - - -def check_env(device_type: str): - """check all environment.""" - logger = get_logger('lmdeploy') - logger.info('Checking environment for PyTorch Engine.') + """check Deeplink environment.""" check_env_deeplink(device_type) - check_env_torch() - if device_type == 'cuda': - check_env_triton('cuda') - - -MIN_TRANSFORMERS_VERSION = '4.33.0' -MAX_TRANSFORMERS_VERSION = '4.44.1' - - -def check_awq(hf_config, device_type): - """check awq support.""" - logger = get_logger('lmdeploy') - if device_type == 'cuda': - quantization_config = getattr(hf_config, 'quantization_config', dict()) - quant_method = quantization_config.get('quant_method', None) - if quant_method != 'awq': - return - try: - import awq # noqa - except Exception as e: - _handle_exception(e, 'autoawq', logger) - - try: - import awq_ext # noqa - except Exception: - logger.debug('Exception:', exc_info=1) - logger.warning('Failed to import `awq_ext`. ' - 'Try reinstall it from source: ' - 'https://github.com/casper-hansen/AutoAWQ_kernels') - - -def check_transformers_version(model_path: str, - trust_remote_code: bool = True, - dtype: str = 'auto', - device_type: str = 'cuda'): - """check transformers version.""" - from packaging import version - logger = get_logger('lmdeploy') - - def __check_transformers_version(): - """check transformers version.""" - logger.debug('Checking version.') - trans_version = None - try: - import transformers - trans_version = version.parse(transformers.__version__) - min_version = version.parse(MIN_TRANSFORMERS_VERSION) - max_version = version.parse(MAX_TRANSFORMERS_VERSION) - if trans_version < min_version or trans_version > max_version: - logger.warning('LMDeploy requires transformers version: ' - f'[{MIN_TRANSFORMERS_VERSION} ~ ' - f'{MAX_TRANSFORMERS_VERSION}], ' - 'but found version: ' - f'{transformers.__version__}') - except Exception as e: - _handle_exception(e, 'transformers', logger) - return transformers, trans_version - - def __check_config(trans_version): - """check config.""" - logger.debug('Checking AutoConfig.from_pretrained.') - try: - from transformers import AutoConfig - config = AutoConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code) - except Exception as e: - message = ( - f'Load model config with transformers=={trans_version}' - ' failed. ' - 'Please make sure model can be loaded with transformers API.') - _handle_exception(e, 'transformers', logger, message=message) - return config - - def __check_model_transformers_version(config, trans_version): - """check model transformers version.""" - logger.debug('Checking required transformers version.') - try: - model_trans_version = getattr(config, 'transformers_version', None) - if model_trans_version is not None: - model_trans_version = version.parse(model_trans_version) - assert trans_version >= model_trans_version, \ - 'Version mismatch.' - except Exception as e: - message = (f'model `{model_path}` requires ' - f'transformers version {model_trans_version} ' - f'but transformers {trans_version} is installed.') - _handle_exception(e, 'transformers', logger, message=message) - - def __check_model_dtype_support(config, device_type): - """Checking model dtype support.""" - logger.debug('Checking dtype support.') - - import torch - - from lmdeploy.pytorch.config import ModelConfig - from lmdeploy.utils import is_bf16_supported - - try: - model_config = ModelConfig.from_hf_config(config, - model_path=model_path, - dtype=dtype) - if model_config.dtype == torch.bfloat16: - assert is_bf16_supported(device_type), ( - 'bf16 is not supported on your device') - except AssertionError as e: - message = ( - f'Your device does not support `{model_config.dtype}`. ' - 'You can set `dtype` to float16 in PyTorchEngineConfig or ' - '`--dtype float16` to api_server.\n' - 'Note that this might have negative effect!') - _handle_exception(e, 'Model', logger, message=message) - except Exception as e: - message = (f'Checking failed with error {e}', - 'Please send issue to LMDeploy with error logs.') - _handle_exception(e, 'Model', logger, message=message) - - return model_config - - _, trans_version = __check_transformers_version() - config = __check_config(trans_version) - __check_model_transformers_version(config, trans_version) - __check_model_dtype_support(config, device_type) - check_awq(config, device_type) - - -def check_model(model_path: str, - trust_remote_code: bool = True, - dtype: str = 'auto', - device_type: str = 'cuda'): - """check model requirements.""" - logger = get_logger('lmdeploy') - logger.info('Checking model.') - check_transformers_version(model_path, trust_remote_code, dtype, - device_type) - - -def check_adapter(path: str): - """check adapter.""" - logger = get_logger('lmdeploy') - logger.debug(f'Checking : {path}.') - - try: - from peft import PeftConfig - PeftConfig.from_pretrained(path) - except Exception as e: - message = ('Please make sure the adapter can be loaded with ' - '`peft.PeftConfig.from_pretrained`\n') - err_msg = '' if len(e.args) == 0 else e.args[0] - if 'got an unexpected keyword argument' in err_msg: - message += ('Or try remove all unexpected keywords ' - 'in `adapter_config.json`.') - _handle_exception(e, 'Model', logger, message=message) - - -def check_adapters(adapter_paths: List[str]): - """check adapters.""" - if len(adapter_paths) <= 0: - return - logger = get_logger('lmdeploy') - logger.info('Checking adapters.') - for path in adapter_paths: - check_adapter(path) diff --git a/lmdeploy/pytorch/check_env/adapter.py b/lmdeploy/pytorch/check_env/adapter.py new file mode 100644 index 0000000000..bcaf5fd0e3 --- /dev/null +++ b/lmdeploy/pytorch/check_env/adapter.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseChecker + + +class AdapterChecker(BaseChecker): + """check adapter is available.""" + + def __init__(self, adapter_path: str, logger=None): + super().__init__(logger) + self.adapter_path = adapter_path + + def check(self): + """check.""" + path = self.adapter_path + + try: + import peft # noqa: F401 + except Exception as e: + self.log_and_exit(e, 'Adapter', message='Failed to import peft.') + + try: + from peft import PeftConfig + PeftConfig.from_pretrained(path) + except Exception as e: + message = ('Please make sure the adapter can be loaded with ' + '`peft.PeftConfig.from_pretrained`\n') + err_msg = '' if len(e.args) == 0 else e.args[0] + if 'got an unexpected keyword argument' in err_msg: + message += ('Or try remove all unexpected keywords ' + 'in `adapter_config.json`.') + self.log_and_exit(e, 'Adapter', message=message) diff --git a/lmdeploy/pytorch/check_env/base.py b/lmdeploy/pytorch/check_env/base.py new file mode 100644 index 0000000000..ed5e5a600f --- /dev/null +++ b/lmdeploy/pytorch/check_env/base.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from logging import Logger +from typing import List + +from lmdeploy.utils import get_logger + +RED_COLOR = '\033[31m' +RESET_COLOR = '\033[0m' + + +def _red_text(text: str): + """red text.""" + return f'{RED_COLOR}{text}{RESET_COLOR}' + + +class BaseChecker: + """base checker.""" + + def __init__(self, logger: Logger = None): + if logger is None: + logger = get_logger('lmdeploy') + self.logger = logger + self._is_passed = False + self._required_checker: List[BaseChecker] = list() + + def get_logger(self): + """get logger.""" + return self.logger + + def register_required_checker(self, checker: 'BaseChecker'): + """register_required.""" + self._required_checker.append(checker) + + def handle(self): + """handle check.""" + is_passed = getattr(self, '_is_passed', False) + if not is_passed: + checker_name = type(self).__name__ + self.logger.debug(f'Checking <{checker_name}>:') + for checker in self._required_checker: + checker.handle() + self.check() + self.is_passed = True + + def log_and_exit(self, + e: Exception = None, + mod_name: str = None, + message: str = None): + logger = self.logger + if mod_name is None: + mod_name = type(self).__name__ + if message is None: + message = 'Please check your environment.' + logger.debug('Exception', exc_info=1) + if e is not None: + logger.error(f'{type(e).__name__}: {e}') + logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}') + exit(1) + + def check(self): + """check.""" + raise NotImplementedError('check not implemented.') diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py new file mode 100644 index 0000000000..74ab5a7b87 --- /dev/null +++ b/lmdeploy/pytorch/check_env/deeplink.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseChecker + +deeplink_device_type_list = [ + 'ascend', + 'npu', + 'maca', +] + + +class DeeplinkChecker(BaseChecker): + """check pytorch is available.""" + + def __init__(self, device_type: str, logger=None) -> None: + super().__init__(logger=logger) + self.device_type = device_type + + def check(self): + """check.""" + device_type = self.device_type + if device_type in deeplink_device_type_list: + try: + import dlinfer.framework.lmdeploy_ext # noqa: F401 + except Exception as e: + self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.') diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py new file mode 100644 index 0000000000..4b721e50e2 --- /dev/null +++ b/lmdeploy/pytorch/check_env/model.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from packaging import version + +from .base import BaseChecker + + +class ModelChecker(BaseChecker): + """check model is available.""" + + def __init__(self, + model_path: str, + trust_remote_code: bool, + dtype: str, + device_type: str, + logger=None) -> None: + super().__init__(logger=logger) + self.model_path = model_path + self.trust_remote_code = trust_remote_code + self.device_type = device_type + self.dtype = dtype + + def check_config(self, trans_version): + """check config.""" + model_path = self.model_path + trust_remote_code = self.trust_remote_code + try: + from transformers import AutoConfig + config = AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code) + except Exception as e: + message = ( + f'Load model config with transformers=={trans_version}' + ' failed. ' + 'Please make sure model can be loaded with transformers API.') + self.log_and_exit(e, 'transformers', message=message) + return config + + def check_trans_version(self, config, trans_version): + """check transformers version.""" + model_path = self.model_path + try: + model_trans_version = getattr(config, 'transformers_version', None) + if model_trans_version is not None: + model_trans_version = version.parse(model_trans_version) + assert trans_version >= model_trans_version, ( + 'Version mismatch.') + except Exception as e: + message = (f'model `{model_path}` requires ' + f'transformers version {model_trans_version} ' + f'but transformers {trans_version} is installed.') + self.log_and_exit(e, 'transformers', message=message) + + def check_dtype(self, config): + """check dtype.""" + logger = self.get_logger() + model_path = self.model_path + device_type = self.device_type + dtype = self.dtype + try: + import torch + + from lmdeploy.pytorch.config import ModelConfig + from lmdeploy.utils import is_bf16_supported + model_config = ModelConfig.from_hf_config(config, + model_path=model_path, + dtype=dtype) + if model_config.dtype == torch.bfloat16: + if not is_bf16_supported(device_type): + logger.warning('Device does not support bfloat16.') + except Exception as e: + message = (f'Checking failed with error {e}', + 'Please send issue to LMDeploy with error logs.') + self.log_and_exit(e, 'Model', message=message) + + def check_awq(self, config): + """check awq.""" + logger = self.get_logger() + device_type = self.device_type + if device_type != 'cuda': + return + + quantization_config = getattr(config, 'quantization_config', dict()) + quant_method = quantization_config.get('quant_method', None) + if quant_method != 'awq': + return + try: + import awq # noqa + except Exception as e: + self.log_and_exit(e, 'autoawq', logger) + + try: + import awq_ext # noqa + except Exception as e: + logger.debug('Exception:', exc_info=1) + self.log_and_exit( + e, + 'awq_ext', + message='Failed to import `awq_ext`. ' + 'Try reinstall it from source: ' + 'https://github.com/casper-hansen/AutoAWQ_kernels') + + def check(self): + """check.""" + import transformers + trans_version = version.parse(transformers.__version__) + + # config + config = self.check_config(trans_version) + + # transformers version + self.check_trans_version(config, trans_version) + + # dtype check + self.check_dtype(config) + + # awq + self.check_awq(config) diff --git a/lmdeploy/pytorch/check_env/torch.py b/lmdeploy/pytorch/check_env/torch.py new file mode 100644 index 0000000000..14b24e04a0 --- /dev/null +++ b/lmdeploy/pytorch/check_env/torch.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseChecker + + +class TorchChecker(BaseChecker): + """check pytorch is available.""" + + def __init__(self, device: str = 'cuda', logger=None) -> None: + super().__init__(logger=logger) + self.device = device + + def check(self): + """check.""" + try: + import torch + a = torch.tensor([1, 2], device=self.device) + b = a.new_tensor([3, 4], device=self.device) + c = a + b + torch.testing.assert_close(c, a.new_tensor([4, 6])) + except Exception as e: + self.log_and_exit(e, 'PyTorch', 'PyTorch is not available.') diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py new file mode 100644 index 0000000000..9d97cd6dca --- /dev/null +++ b/lmdeploy/pytorch/check_env/transformers.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from packaging import version + +from .base import BaseChecker + +MIN_TRANSFORMERS_VERSION = '4.33.0' +MAX_TRANSFORMERS_VERSION = '4.46.1' + + +class TransformersChecker(BaseChecker): + """check transformers is available.""" + + def check(self): + """check.""" + import transformers + logger = self.get_logger() + try: + trans_version = version.parse(transformers.__version__) + min_version = version.parse(MIN_TRANSFORMERS_VERSION) + max_version = version.parse(MAX_TRANSFORMERS_VERSION) + if trans_version < min_version or trans_version > max_version: + logger.warning('LMDeploy requires transformers version: ' + f'[{MIN_TRANSFORMERS_VERSION} ~ ' + f'{MAX_TRANSFORMERS_VERSION}], ' + 'but found version: ' + f'{transformers.__version__}') + except Exception as e: + self.log_and_exit(e, 'transformers', + 'transformers is not available.') diff --git a/lmdeploy/pytorch/check_env/triton.py b/lmdeploy/pytorch/check_env/triton.py new file mode 100644 index 0000000000..4cc58c5492 --- /dev/null +++ b/lmdeploy/pytorch/check_env/triton.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from packaging import version + +from .base import BaseChecker + +MAX_TRITON_VERSION = '3.1.0' +MIN_TRITON_VERSION = '3.0.0' + + +class TritonChecker(BaseChecker): + """check triton is available.""" + + def check_version(self): + """check version.""" + logger = self.get_logger() + + # version check + import triton + max_version = version.parse(MAX_TRITON_VERSION) + min_version = version.parse(MIN_TRITON_VERSION) + triton_version = version.parse(triton.__version__) + + if triton_version > max_version: + logger.warning('PytorchEngine has not been tested on ' + f'triton>{MAX_TRITON_VERSION}.') + if triton_version < min_version: + msg = (f'triton>={MIN_TRITON_VERSION} is required. ' + f'Found triton=={triton_version}') + self.log_and_exit(mod_name='Triton', message=msg) + + def check(self): + """check.""" + logger = self.get_logger() + + msg = ( + 'Please ensure that your device is functioning properly with .\n' # noqa: E501 + 'You can verify your environment by running ' + '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.') + try: + logger.debug('Checking environment.') + import torch + + from .triton_custom_add import custom_add + a = torch.tensor([1, 2], device='cuda') + b = a.new_tensor([3, 4], device='cuda') + c = custom_add(a, b) + torch.testing.assert_close(c, a + b) + except RuntimeError as e: + ptxas_error = 'device kernel image is invalid' + if len(e.args) > 0 and ptxas_error in e.args[0]: + msg = ( + 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501 + 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501 + ' or reinstall the driver.') + self.log_and_exit(e, 'Triton', msg) + except Exception as e: + self.log_and_exit(e, 'Triton', msg) + + # version check + self.check_version() diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index c350f4b4cf..7783afd970 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -26,6 +26,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): return config torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + # deal with case when torch_dtype is not string but torch.dtype + if isinstance(torch_dtype, torch.dtype): + torch_dtype = str(torch_dtype).split('.')[1] + if torch_dtype is None: _dtype = 'float16' if dtype == 'auto' else dtype logger.warning('Model config does not have `torch_dtype`,' @@ -37,8 +41,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): # change to user specified data type if it is not 'auto' if dtype == 'auto': torch_dtype = torch_dtype if torch_dtype in [ - torch.float16, torch.bfloat16 - ] else torch.float16 + 'float16', 'bfloat16' + ] else 'float16' else: torch_dtype = dtype config.dtype = eval(f'torch.{torch_dtype}') @@ -77,6 +81,7 @@ class CacheConfig: max_prefill_token_num: int = 4096 enable_prefix_caching: bool = False quant_policy: Literal[0, 4, 8] = 0 + device_type: str = 'cuda' def __post_init__(self): """post init.""" @@ -103,7 +108,6 @@ class ModelConfig: v_head_dim: int = None sliding_window: int = -1 dtype: torch.dtype = torch.float16 - multi_query_attention: bool = False vocab_size: int = 40000 hf_config: Any = None cogvlm_style: bool = False @@ -117,7 +121,8 @@ def get_head_size(self): def from_pretrained(cls, pretrained_model_name_or_path: str, trust_remote_code: bool = True, - dtype: str = 'auto'): + dtype: str = 'auto', + tp: int = 1): """Instantiate one of the configuration classes of the library from a pretrained model configuration. @@ -137,17 +142,21 @@ def from_pretrained(cls, pretrained_model_name_or_path) return cls.from_hf_config(hf_config, pretrained_model_name_or_path, - dtype=dtype) + dtype=dtype, + tp=tp) @classmethod def from_hf_config(cls, hf_config: Any, model_path: str = None, - dtype: str = 'auto'): + dtype: str = 'auto', + tp: int = 1): """from huggingface config.""" from lmdeploy.pytorch.configurations import AutoModelConfigBuilder - model_config = AutoModelConfigBuilder.build(hf_config, model_path) + model_config = AutoModelConfigBuilder.build(hf_config, + model_path, + tp=tp) if model_config.k_head_dim is None: assert model_config.head_dim is not None @@ -156,6 +165,13 @@ def from_hf_config(cls, assert model_config.head_dim is not None model_config.v_head_dim = model_config.head_dim + # check for tp + assert model_config.num_attention_heads % tp == 0 + if model_config.num_key_value_heads >= tp: + assert model_config.num_key_value_heads % tp == 0 + else: + assert tp % model_config.num_key_value_heads == 0 + # should after setting `hf_config` and `model_arch` attributes model_config = _update_torch_dtype(model_config, dtype) diff --git a/lmdeploy/pytorch/configurations/builder.py b/lmdeploy/pytorch/configurations/builder.py index 89bf51ca46..bafa78ba02 100644 --- a/lmdeploy/pytorch/configurations/builder.py +++ b/lmdeploy/pytorch/configurations/builder.py @@ -27,7 +27,7 @@ def condition(cls, hf_config): f'`condition` of {cls.__name__} not implemented.') @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" from .default import DefaultModelConfigBuilder @@ -46,8 +46,21 @@ def build(cls, hf_config, model_path: str = None): logger.debug(f'build model config with {valid_builder.__name__}') - cfg = valid_builder.build(hf_config, model_path) + cfg = valid_builder.build(hf_config, model_path, **kwargs) if cfg.hf_config is None: cfg.hf_config = hf_config return cfg + + @classmethod + def update_num_kv_heads(cls, hf_config, tp, num_key_value_heads): + """update num kv heads.""" + # update num_kv_heads for tp mode + if tp > 1 and tp > num_key_value_heads: + assert tp % num_key_value_heads == 0 + n_replicate = tp // num_key_value_heads + hf_config.num_replicate_key_value_heads = n_replicate + num_key_value_heads = tp + + hf_config.num_key_value_heads = num_key_value_heads + return num_key_value_heads diff --git a/lmdeploy/pytorch/configurations/chatglm.py b/lmdeploy/pytorch/configurations/chatglm.py index 7911c985d5..fbf4d48281 100644 --- a/lmdeploy/pytorch/configurations/chatglm.py +++ b/lmdeploy/pytorch/configurations/chatglm.py @@ -12,16 +12,27 @@ def condition(cls, hf_config): return hf_config.model_type == 'chatglm' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = hf_config.hidden_size // hf_config.num_attention_heads bos_token_id = hf_config.bos_token_id if bos_token_id is None: bos_token_id = hf_config.pad_token_id + + if hf_config.multi_query_attention: + num_key_value_heads = hf_config.multi_query_group_num + else: + num_key_value_heads = hf_config.num_attention_heads + + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, + num_key_value_heads) + cfg = ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_layers, num_attention_heads=hf_config.num_attention_heads, - num_key_value_heads=hf_config.multi_query_group_num, + num_key_value_heads=num_key_value_heads, bos_token_id=bos_token_id, eos_token_id=hf_config.eos_token_id, head_dim=head_dim, diff --git a/lmdeploy/pytorch/configurations/cogvlm.py b/lmdeploy/pytorch/configurations/cogvlm.py index b24d92d794..4736dfee69 100644 --- a/lmdeploy/pytorch/configurations/cogvlm.py +++ b/lmdeploy/pytorch/configurations/cogvlm.py @@ -12,12 +12,15 @@ def condition(cls, hf_config): return model_arch == 'CogVLMForCausalLM' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" from lmdeploy.utils import is_bf16_supported - cfg = DefaultModelConfigBuilder.build(hf_config) if getattr(hf_config, 'num_multi_query_heads', None): - cfg.num_key_value_heads = hf_config.num_multi_query_heads + hf_config.num_key_value_heads = hf_config.num_multi_query_heads + else: + hf_config.num_key_value_heads = hf_config.num_attention_heads + + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) cfg.cogvlm_style = True torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16' hf_config.torch_dtype = torch_dtype diff --git a/lmdeploy/pytorch/configurations/dbrx.py b/lmdeploy/pytorch/configurations/dbrx.py index 2c8128a5a6..dcc1222b0d 100644 --- a/lmdeploy/pytorch/configurations/dbrx.py +++ b/lmdeploy/pytorch/configurations/dbrx.py @@ -12,7 +12,7 @@ def condition(cls, hf_config): return hf_config.model_type == 'dbrx' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" hidden_size = hf_config.d_model num_heads = hf_config.n_heads diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py index 37aa4b0d69..d1f0844ad5 100644 --- a/lmdeploy/pytorch/configurations/deepseek_v2.py +++ b/lmdeploy/pytorch/configurations/deepseek_v2.py @@ -12,13 +12,19 @@ def condition(cls, hf_config): return hf_config.model_type == 'deepseek_v2' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim) k_head_dim = head_dim v_head_dim = 0 num_attention_heads = hf_config.num_attention_heads + # multi query attn num_key_value_heads = 1 + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, + num_key_value_heads) + return ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_hidden_layers, num_attention_heads=num_attention_heads, @@ -28,5 +34,4 @@ def build(cls, hf_config, model_path: str = None): head_dim=head_dim, k_head_dim=k_head_dim, v_head_dim=v_head_dim, - vocab_size=hf_config.vocab_size, - multi_query_attention=True) + vocab_size=hf_config.vocab_size) diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py index 1f84b810ea..d1337a241e 100644 --- a/lmdeploy/pytorch/configurations/default.py +++ b/lmdeploy/pytorch/configurations/default.py @@ -12,7 +12,7 @@ def condition(cls, hf_config): return True @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = hf_config.hidden_size // hf_config.num_attention_heads num_attention_heads = hf_config.num_attention_heads @@ -23,6 +23,11 @@ def build(cls, hf_config, model_path: str = None): if use_sliding_window: sliding_window = getattr(hf_config, 'sliding_window', sliding_window) or -1 + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, + num_key_value_heads) + return ModelConfig( hidden_size=hf_config.hidden_size, num_layers=hf_config.num_hidden_layers, diff --git a/lmdeploy/pytorch/configurations/falcon.py b/lmdeploy/pytorch/configurations/falcon.py index db4d00e397..a4c8d4d44f 100644 --- a/lmdeploy/pytorch/configurations/falcon.py +++ b/lmdeploy/pytorch/configurations/falcon.py @@ -12,7 +12,7 @@ def condition(cls, hf_config): return hf_config.model_type == 'falcon' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build falcon.""" num_attention_heads = hf_config.num_attention_heads if hf_config.new_decoder_architecture: @@ -24,6 +24,12 @@ def build(cls, hf_config, model_path: str = None): else: # rw-1b, MHA kv_head = num_attention_heads + + tp = kwargs.get('tp', 1) + # update num_kv_heads for tp mode + kv_head = cls.update_num_kv_heads(hf_config, tp, kv_head) + hf_config.num_kv_heads = kv_head + head_dim = hf_config.hidden_size // num_attention_heads return ModelConfig( hidden_size=hf_config.hidden_size, @@ -33,6 +39,5 @@ def build(cls, hf_config, model_path: str = None): bos_token_id=hf_config.bos_token_id, eos_token_id=hf_config.eos_token_id, head_dim=head_dim, - multi_query_attention=hf_config.multi_query, vocab_size=hf_config.vocab_size, ) diff --git a/lmdeploy/pytorch/configurations/gemma.py b/lmdeploy/pytorch/configurations/gemma.py index 338eaee6d0..d49fdbd96c 100644 --- a/lmdeploy/pytorch/configurations/gemma.py +++ b/lmdeploy/pytorch/configurations/gemma.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.pytorch.config import ModelConfig - from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder class GemmaModelConfigBuilder(AutoModelConfigBuilder): @@ -12,13 +11,8 @@ def condition(cls, hf_config): return hf_config.model_type in ['gemma', 'gemma2'] @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build gemma.""" - return ModelConfig(hidden_size=hf_config.hidden_size, - num_layers=hf_config.num_hidden_layers, - num_attention_heads=hf_config.num_attention_heads, - num_key_value_heads=hf_config.num_key_value_heads, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - head_dim=hf_config.head_dim, - vocab_size=hf_config.vocab_size) + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) + cfg.head_dim = hf_config.head_dim + return cfg diff --git a/lmdeploy/pytorch/configurations/internvl.py b/lmdeploy/pytorch/configurations/internvl.py index 76b4187c5f..ffff0a0e15 100644 --- a/lmdeploy/pytorch/configurations/internvl.py +++ b/lmdeploy/pytorch/configurations/internvl.py @@ -11,8 +11,9 @@ def condition(cls, hf_config): return hf_config.architectures[0] == 'InternVLChatModel' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build llava hf.""" - cfg = DefaultModelConfigBuilder.build(hf_config.llm_config) + cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path, + **kwargs) cfg.hf_config = hf_config return cfg diff --git a/lmdeploy/pytorch/configurations/llava.py b/lmdeploy/pytorch/configurations/llava.py deleted file mode 100644 index aaeeeeadfe..0000000000 --- a/lmdeploy/pytorch/configurations/llava.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .builder import AutoModelConfigBuilder -from .default import DefaultModelConfigBuilder - - -class LlavaModelConfigBuilder(AutoModelConfigBuilder): - - @classmethod - def condition(cls, hf_config): - """config.""" - return hf_config.architectures[0] in [ - 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM' - ] - - @classmethod - def build(cls, hf_config, model_path: str = None): - """build.""" - arch = hf_config.architectures[0] - if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']: - from llava.model.language_model.llava_llama import LlavaConfig - - # reload hf_config due to model_type='llava' is already - # registered in transformers - hf_config = LlavaConfig.from_pretrained(model_path) - cfg = DefaultModelConfigBuilder.build(hf_config) - return cfg diff --git a/lmdeploy/pytorch/configurations/llava_hf.py b/lmdeploy/pytorch/configurations/llava_hf.py index 4cc007e313..5334eaec25 100644 --- a/lmdeploy/pytorch/configurations/llava_hf.py +++ b/lmdeploy/pytorch/configurations/llava_hf.py @@ -15,7 +15,7 @@ def condition(cls, hf_config): ] @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build llava hf.""" text_config = hf_config.text_config hidden_size = getattr(text_config, 'hidden_size', 4096) diff --git a/lmdeploy/pytorch/configurations/minicpm3.py b/lmdeploy/pytorch/configurations/minicpm3.py index 7cde51bd42..857673aab3 100644 --- a/lmdeploy/pytorch/configurations/minicpm3.py +++ b/lmdeploy/pytorch/configurations/minicpm3.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.pytorch.config import ModelConfig from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder class MiniCPM3ModelConfigBuilder(AutoModelConfigBuilder): @@ -12,21 +12,13 @@ def condition(cls, hf_config): return hf_config.architectures[0] in ['MiniCPM3ForCausalLM'] @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim) - k_head_dim = head_dim - v_head_dim = head_dim - num_attention_heads = hf_config.num_attention_heads - num_key_value_heads = hf_config.num_key_value_heads - return ModelConfig(hidden_size=hf_config.hidden_size, - num_layers=hf_config.num_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - head_dim=head_dim, - k_head_dim=k_head_dim, - v_head_dim=v_head_dim, - vocab_size=hf_config.vocab_size, - multi_query_attention=False) + + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) + cfg.head_dim = head_dim + cfg.k_head_dim = head_dim + cfg.v_head_dim = head_dim + + return cfg diff --git a/lmdeploy/pytorch/configurations/mllama.py b/lmdeploy/pytorch/configurations/mllama.py index 2383c92c50..e56e0fbed4 100644 --- a/lmdeploy/pytorch/configurations/mllama.py +++ b/lmdeploy/pytorch/configurations/mllama.py @@ -11,8 +11,9 @@ def condition(cls, hf_config): return hf_config.architectures[0] == 'MllamaForConditionalGeneration' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build llava hf.""" - cfg = DefaultModelConfigBuilder.build(hf_config.text_config) + cfg = DefaultModelConfigBuilder.build(hf_config.text_config, + model_path, **kwargs) cfg.hf_config = hf_config return cfg diff --git a/lmdeploy/pytorch/configurations/qwen.py b/lmdeploy/pytorch/configurations/qwen.py index 05ac77c1d1..eda726de43 100644 --- a/lmdeploy/pytorch/configurations/qwen.py +++ b/lmdeploy/pytorch/configurations/qwen.py @@ -11,10 +11,10 @@ def condition(cls, hf_config): return hf_config.model_type == 'qwen' @classmethod - def build(cls, hf_config, model_path: str = None): + def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" from lmdeploy.utils import is_bf16_supported - cfg = DefaultModelConfigBuilder.build(hf_config) + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) if cfg.bos_token_id is None: cfg.bos_token_id = 151644 if cfg.eos_token_id is None: diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index e393adeed3..e3f97cfe46 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -44,7 +44,13 @@ def __init__( self.num_layers = model_config.num_layers self.kv_cache_dtype = model_config.dtype if cache_config.quant_policy > 0: - self.kv_cache_dtype = torch.uint8 + if self.cache_config.device_type in ['cuda']: + self.kv_cache_dtype = torch.uint8 + elif self.cache_config.device_type in ['ascend', 'npu']: + self.kv_cache_dtype = torch.int8 + else: + raise ValueError( + f'unsupported device_type {self.cache_config.device_type}') # Initialize the cache. self.local_gpu_cache = self.allocate_gpu_cache() @@ -92,7 +98,7 @@ def _get_key_block_shape_impl(cls, attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads - if local and not model_config.multi_query_attention: + if local: assert num_heads % world_size == 0, \ f'num_heads: {num_heads}, world_size: {world_size}' num_heads = num_heads // world_size @@ -115,7 +121,7 @@ def _get_value_block_shape_impl(cls, attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads - if local and not model_config.multi_query_attention: + if local: assert num_heads % world_size == 0, \ f'num_heads: {num_heads}, world_size: {world_size}' num_heads = num_heads // world_size @@ -202,7 +208,7 @@ def allocate_gpu_cache(self): def allocate_cpu_cache(self): """allocate caches on Host.""" - caches = self._allocate_cache(self.num_gpu_blocks, 'cpu') + caches = self._allocate_cache(self.num_cpu_blocks, 'cpu') self.full_cpu_cache = caches self.local_cpu_cache = list(zip(*caches)) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index b7a803a7a7..e06e0cf80a 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -8,19 +8,17 @@ import numpy as np import torch -from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, - ResponseType) +from lmdeploy.messages import PytorchEngineConfig, ResponseType from lmdeploy.utils import (get_logger, get_max_batch_size, get_model, logging_timer) from ..adapter.adapter import AdapterManager -from ..check_env import check_adapters, check_env, check_model from ..config import BackendConfig, CacheConfig, SchedulerConfig from ..devices import DeviceContext, get_device_manager -from ..messages import (InputEmbeddingRangeType, InputEmbeddingType, - MessageStatus, SchedulerSequence) -from ..model_inputs import ModelInputs, MRopeModelInputs, VisionModelInputs +from ..messages import MessageStatus, SchedulerSequence +from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler +from .engine_checker import EngineChecker from .logits_process import FusedLogitsProcessor, SamplingInputs from .model_agent import build_model_agent from .request import Request, RequestManager, RequestType, Response @@ -78,6 +76,40 @@ def _check_finish(scheduler: Scheduler, current_iter: int): return False +def _build_scheduler_config(engine_config: PytorchEngineConfig): + """build scheduler config.""" + scheduler_config = SchedulerConfig( + max_batches=engine_config.max_batch_size, + max_session_len=engine_config.session_len, + prefill_interval=engine_config.prefill_interval) + return scheduler_config + + +def _build_cache_config(engine_config: PytorchEngineConfig): + """build cache config.""" + cache_config = CacheConfig( + max_batches=engine_config.max_batch_size, + block_size=engine_config.block_size, + num_cpu_blocks=engine_config.num_cpu_blocks, + num_gpu_blocks=engine_config.num_gpu_blocks, + cache_max_entry_count=engine_config.cache_max_entry_count, + max_prefill_token_num=engine_config.max_prefill_token_num, + enable_prefix_caching=engine_config.enable_prefix_caching, + quant_policy=engine_config.quant_policy, + device_type=engine_config.device_type, + ) + return cache_config + + +def _build_backend_config(engine_config: PytorchEngineConfig): + """build backend config.""" + backend_config = BackendConfig( + eager_mode=engine_config.eager_mode, + device_type=engine_config.device_type, + ) + return backend_config + + class Engine: """The inference engine of lmdeploy pytorch. @@ -95,43 +127,23 @@ def __init__(self, engine_config = PytorchEngineConfig() else: engine_config = copy.deepcopy(engine_config) - check_env(engine_config.device_type) - check_model(model_path, trust_remote_code, engine_config.dtype, - engine_config.device_type) if engine_config.max_batch_size is None: engine_config.max_batch_size = get_max_batch_size( engine_config.device_type) - adapters = engine_config.adapters - if adapters is not None: - check_adapters(list(adapters.values())) - assert engine_config.max_batch_size > 0, 'max_batch_size should be' \ - f' greater than 0, but got {engine_config.max_batch_size}' - assert engine_config.dtype in ['auto', 'float16', 'bfloat16'], \ - f'unsupported specified data type {engine_config.dtype}' + checker = EngineChecker(model_path=model_path, + engine_config=engine_config, + trust_remote_code=trust_remote_code, + logger=logger) + checker.handle() + + adapters = engine_config.adapters self.engine_config = engine_config self.tp = engine_config.tp self.device_context = DeviceContext( device_type=engine_config.device_type) - scheduler_config = SchedulerConfig( - max_batches=engine_config.max_batch_size, - max_session_len=engine_config.session_len, - prefill_interval=engine_config.prefill_interval) - - # block_size = 1 to enable unified paging - cache_config = CacheConfig( - max_batches=engine_config.max_batch_size, - block_size=engine_config.block_size, - num_cpu_blocks=engine_config.num_cpu_blocks, - num_gpu_blocks=engine_config.num_gpu_blocks, - cache_max_entry_count=engine_config.cache_max_entry_count, - max_prefill_token_num=engine_config.max_prefill_token_num, - enable_prefix_caching=engine_config.enable_prefix_caching, - quant_policy=engine_config.quant_policy, - ) - if not os.path.exists(model_path): model_path = get_model(model_path, engine_config.download_dir, engine_config.revision) @@ -140,10 +152,9 @@ def __init__(self, if adapters is not None and len(adapters) > 0: adapters = self._download_adapters(adapters, engine_config) - backend_config = BackendConfig( - eager_mode=engine_config.eager_mode, - device_type=engine_config.device_type, - ) + scheduler_config = _build_scheduler_config(engine_config) + cache_config = _build_cache_config(engine_config) + backend_config = _build_backend_config(engine_config) with get_device_manager().context(self.device_context): self.model_agent = build_model_agent( @@ -156,6 +167,8 @@ def __init__(self, dtype=engine_config.dtype, custom_module_map=engine_config.custom_module_map) + self.input_processor = self.model_agent.get_input_processor() + cache_config = self.model_agent.cache_config self.adapter_manager = self._build_adapter_manager(adapters) self.scheduler = Scheduler(scheduler_config, cache_config) @@ -171,7 +184,6 @@ def __init__(self, # create main thread self._start_loop() self._create_buffers() - self.engine_instance = self.create_instance() self._output_stream = torch.cuda.Stream() @classmethod @@ -316,6 +328,10 @@ def _on_end_session(self, reqs: Request, **kwargs): def _on_add_message(self, reqs: Request, **kwargs): """on add message callback.""" + self._msg_preprocess_inque.put_nowait(reqs) + + def _add_message(self, que): + def __update_bad_words(msg): """update bad words.""" sampling_param = msg.sampling_param @@ -337,6 +353,11 @@ def __update_max_new_tokens(msg): sampling_param.max_new_tokens, max_session_len - msg.num_all_tokens()) + if que.qsize() == 0: + return + + reqs = que.get_nowait() + for req in reqs: session_id = req.data['session_id'] if session_id not in self.scheduler.sessions: @@ -354,11 +375,8 @@ def __update_max_new_tokens(msg): sampling_param=req.data['sampling_param'], adapter_name=req.data['adapter_name'], return_logits=req.data.get('return_logits', False), + multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings'), - mrope_position_ids=req.data.get('mrope_position_ids'), - mrope_position_delta=req.data.get('mrope_position_delta'), - cross_attention_states=req.data.get( - 'cross_attention_states'), ) msg = next(iter(sess.sequences.values())) __update_bad_words(msg) @@ -366,9 +384,11 @@ def __update_max_new_tokens(msg): self.scheduler.add_sequence(msg) else: msg = next(iter(sess.sequences.values())) - msg.update_token_ids(req.data['token_ids'], - req.data.get('input_embeddings'), - req.data.get('cross_attention_states')) + msg.update_token_ids( + req.data['token_ids'], + multimodals=req.data.get('input_multimodals'), + embeddings=req.data.get('input_embeddings'), + ) msg.num_new_tokens = 0 msg.sampling_param = req.data['sampling_param'] msg.return_logits = req.data.get('return_logits', False) @@ -414,7 +434,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): seq_length = self._seq_length_buf[:batch_size] max_q_seq_length = seq_length.max().item() - # TODO: get block offsets is slow when block_size = 1 block_offsets = self.scheduler.get_block_tables(messages) block_offsets = _tensorlize_block_offsets(block_offsets) @@ -432,13 +451,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): num_ignored_history = [msg.num_ignored_history for msg in messages] num_ignored_history = torch.tensor(num_ignored_history) - def __get_cogvlm_image_info(): - """Get cogvlm history image info for position ids.""" - history_image_nums = torch.LongTensor( - [msg.history_image_num for msg in messages]) - history_image_token_lengths = torch.LongTensor( - [msg.history_image_token_len for msg in messages]) - return history_image_nums, history_image_token_lengths + model_metas = [msg.model_meta for msg in messages] def __get_vlm_embeddings(): """get vlm input embeddings and indexings.""" @@ -463,25 +476,9 @@ def __get_vlm_embeddings(): return (input_embeddings, input_embedding_indexing, input_embedding_ranges) - def __get_mrope_inputs(): - """get multimodal rotary position inputs.""" - position_ids = [msg.mrope_position_ids for msg in messages] - deltas = [msg.mrope_position_delta for msg in messages] - return MRopeModelInputs(position_ids=position_ids, deltas=deltas) - # for inputs with embeddings history_image_nums = None history_image_token_lengths = None - # only for cogvlm - if self.model_config.cogvlm_style: - (history_image_nums, - history_image_token_lengths) = __get_cogvlm_image_info() - # only for qwen2_vl - mrope_inputs = None - has_mrope_params = any( - [msg.mrope_position_ids is not None for msg in messages]) - if has_mrope_params: - mrope_inputs = __get_mrope_inputs() input_embeddings = None input_embedding_indexing = None @@ -492,25 +489,40 @@ def __get_mrope_inputs(): (input_embeddings, input_embedding_indexing, input_embedding_ranges) = __get_vlm_embeddings() + input_multimodals = None + has_multimodal = any( + [not msg.history_multimodals.empty() for msg in messages]) + if has_multimodal: + has_multimodal = False + input_multimodals = [ + msg.get_input_multimodals() for msg in messages + ] + for input_mm in input_multimodals: + for val in input_mm.values(): + if len(val) > 0: + has_multimodal = True + break + if has_multimodal: + break + vision_embedding_inputs = None - if has_embedding or history_image_nums is not None: + if has_embedding or has_multimodal or history_image_nums is not None: vision_embedding_inputs = VisionModelInputs( history_lengths=history_lengths, history_image_nums=history_image_nums, history_image_token_lengths=history_image_token_lengths, input_embeddings=input_embeddings, input_embedding_indexing=input_embedding_indexing, - input_embedding_ranges=input_embedding_ranges) - - # only for mllama - cross_attention_states = None - history_cross_kv_seqlens = None - if any([msg.cross_attention_states is not None for msg in messages]): - cross_attention_states = [ - msg.cross_attention_states for msg in messages - ] - history_cross_kv_seqlens = torch.tensor( - [msg.history_cross_kv_seqlens for msg in messages]) + input_embedding_ranges=input_embedding_ranges, + input_multimodals=input_multimodals) + + # cross + cross_length = torch.tensor([msg.num_cross for msg in messages]) + history_cross_length = torch.tensor( + [msg.num_history_cross for msg in messages]) + if (cross_length + history_cross_length).max().item() == 0: + cross_length = None + history_cross_length = None return ModelInputs( input_ids=input_ids, @@ -521,9 +533,9 @@ def __get_mrope_inputs(): num_ignored_history=num_ignored_history, local_adapter_ids=local_adapter_ids, vision_inputs=vision_embedding_inputs, - mrope_inputs=mrope_inputs, - cross_attention_states=cross_attention_states, - history_cross_kv_seqlens=history_cross_kv_seqlens, + cross_length=cross_length, + history_cross_length=history_cross_length, + model_metas=model_metas, ) def _batch_stopping_criteria(self, token_ids: torch.Tensor, @@ -567,11 +579,15 @@ def __get_last_logits(): @logging_timer('UpdateRunning', logger) def update_running(self, running: SeqList, next_token_ids: torch.Tensor, - stopped: torch.Tensor): + stopped: torch.Tensor, model_metas: List[Dict[str, + Any]]): """update scheduler.""" + if model_metas is None: + model_metas = [None] * len(running) next_token_ids = next_token_ids.numpy() eos_token_id = self.model_config.eos_token_id - for token, msg, stop in zip(next_token_ids, running, stopped): + for token, msg, stop, model_meta in zip(next_token_ids, running, + stopped, model_metas): if msg.status != MessageStatus.RUNNING: continue update_token = token @@ -580,7 +596,7 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, update_token = _EMPTY_TOKEN else: msg.num_new_tokens += 1 - msg.update_token_ids(update_token) + msg.update_token_ids(update_token, model_meta=model_meta) if stop: msg.status = MessageStatus.STOPPED @@ -646,12 +662,14 @@ async def __long_context_single_forward(inputs): batch_size = seq_len.size(0) assert batch_size == 1 - new_inputs = inputs.split(max_prefill_token_num, - self.cache_config.block_size) + new_inputs = inputs.split(max_prefill_token_num) + model_metas = new_inputs[0].model_metas output_gather = _OutputGather(max_seq_len) for inp in new_inputs: + inp.model_metas = model_metas tmp_out = await __forward(inp) + model_metas = tmp_out.get('model_metas') output_gather.gather(tmp_out) tmp_out.pop('hidden_states', None) tmp_out['hidden_states'] = output_gather.get_output() @@ -673,9 +691,10 @@ async def __long_context_single_forward(inputs): ret['logits'] = logits return ret - def _make_infer_outputs(self, next_token_ids: torch.LongTensor, - logits: torch.Tensor, stopped: torch.Tensor, - event: torch.cuda.Event): + async def _make_infer_outputs(self, next_token_ids: torch.LongTensor, + logits: torch.Tensor, stopped: torch.Tensor, + model_metas: List[Dict[str, Any]], + event: torch.cuda.Event): """make infer output.""" def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, @@ -696,15 +715,16 @@ def __get_q_start_loc(): else: return seq_length.cumsum(0) - seq_length + while not event.query(): + await asyncio.sleep(0.001) with torch.cuda.stream(self._output_stream): - event.wait() next_token_ids = next_token_ids.cpu() stopped = stopped.cpu() running = self._running is_run = [seq.status == MessageStatus.RUNNING for seq in running] stopped = stopped.tolist() - self.update_running(running, next_token_ids, stopped) + self.update_running(running, next_token_ids, stopped, model_metas) # generate output next_token_ids = next_token_ids.tolist() @@ -762,8 +782,7 @@ def __update_inputs(next_token_ids): logger.debug(': ' f'batch_size={inputs.seq_length.size(0)} ' f'num_tokens={inputs.input_ids.size(-1)}') - if self.gpu_count == 1: - inputs = inputs.to_device('cuda') + inputs = inputs.to_device('cuda') is_decoding = inputs.is_decoding if all_ids is not None: all_ids = all_ids.cuda() @@ -794,13 +813,16 @@ def __update_inputs(next_token_ids): next_token_ids, sampling_inputs.stop_words, num_appendable_ids) # send output + model_metas = output.get('model_metas') finish = (idx == loop_count - 1) finish = finish or _check_finish(self.scheduler, idx) event = torch.cuda.Event() event.record() - output = (next_token_ids, logits, stopped, event) + output = (next_token_ids, logits, stopped, model_metas, event) output_que.put_nowait((finish, output)) + inputs.model_metas = model_metas + if finish: break @@ -810,6 +832,36 @@ def __update_inputs(next_token_ids): swap_out_map = dict() __update_inputs(next_token_ids) + @torch.inference_mode() + async def _async_loop_preprocess_message(self, inque, outque): + """preprocess msg.""" + while True: + reqs = await inque.get() + + for req in reqs: + req_data = req.data + if req_data.get('input_multimodals', None) is None: + continue + elif self.input_processor is None: + logger.warning('Do not support Multimodal inputs.') + continue + input_ids = req_data['token_ids'] + input_multimodals = req_data['input_multimodals'] + if len(input_multimodals) == 0: + req_data['input_multimodals'] = None + continue + result = self.input_processor.preprocess_input( + input_ids, input_multimodals) + + input_ids = result.input_ids + input_multimodals = result.input_multimodals + + req_data['token_ids'] = input_ids + req_data['input_multimodals'] = input_multimodals + + if len(reqs) > 0: + outque.put_nowait(reqs) + @torch.inference_mode() async def _async_loop_background(self, in_que: asyncio.Queue, out_que: asyncio.Queue): @@ -918,6 +970,10 @@ async def _async_loop(self): Each engine instance would communicate with the engine by queue. """ + + self._msg_preprocess_inque = asyncio.Queue() + self._msg_preprocess_outque = asyncio.Queue() + prefill_interval = self.scheduler_config.prefill_interval in_que = asyncio.Queue() out_que = asyncio.Queue() @@ -926,6 +982,12 @@ async def _async_loop(self): name='MainLoopBackground') loop_background.add_done_callback(_raise_exception_on_finish) + loop_msg_proc = asyncio.get_event_loop().create_task( + self._async_loop_preprocess_message(self._msg_preprocess_inque, + self._msg_preprocess_outque), + name='MainLoopPreprocessMessage') + loop_msg_proc.add_done_callback(_raise_exception_on_finish) + def __send_resp(out: InferOutput): """send response.""" resp_type = (ResponseType.FINISH @@ -957,13 +1019,14 @@ async def __step(): while not finish: if self.req_manager.has_requests(): self.req_manager.step() + self._add_message(self._msg_preprocess_outque) finish, out = await out_que.get() try: if isinstance(out, Exception): raise out - next_token_ids, logits, stopped, event = out - step_outputs = self._make_infer_outputs( - next_token_ids, logits, stopped, event) + (next_token_ids, logits, stopped, model_metas, event) = out + step_outputs = await self._make_infer_outputs( + next_token_ids, logits, stopped, model_metas, event) __send_resps(step_outputs) except Exception as e: raise e @@ -973,6 +1036,7 @@ async def __step(): while True: if self.req_manager.has_requests(): self.req_manager.step() + self._add_message(self._msg_preprocess_outque) if not self.scheduler.has_unfinished(): await asyncio.sleep(0.01) @@ -996,78 +1060,3 @@ def create_instance(self, cuda_stream_id=0): """ from .engine_instance import EngineInstance return EngineInstance(self) - - async def async_batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None): - """Send inference request. - - Args: - session_ids (List[int]): The session id. - token_ids (List[int]): The input token ids. - gen_config (GenerationConfig): The sampling parameters. - adapter_names (List[str]): The name of the adapters. - keep_cache (bool): Keep kv cache after infer. - - Returns: - int: Error flags. 0 if success. - List[int]: The streaming output tokens. - int: The number of the output tokens. - """ - return await self.engine_instance.async_batched_infer( - session_ids=session_ids, - token_ids=token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - - def batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None): - """batched infer.""" - return self.engine_instance.batched_infer( - session_ids=session_ids, - token_ids=token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - - async def async_add_session(self, session_id: int): - """Add new session.""" - return await self.engine_instance._async_try_add_session(session_id) - - def add_session(self, session_id: int): - """Add new session.""" - return self.engine_instance._try_add_session(session_id) - - async def async_cancel(self, session_id: int): - """Stop the given session.""" - return await self.engine_instance.async_cancel(session_id) - - def cancel(self, session_id: int): - """Add new session.""" - return self.engine_instance.cancel(session_id) - - async def async_end(self, session_id: int): - """End the given session.""" - return await self.engine_instance.async_end(session_id) - - def end(self, session_id: int): - """Add new session.""" - return self.engine_instance.end(session_id) diff --git a/lmdeploy/pytorch/engine/engine_checker.py b/lmdeploy/pytorch/engine/engine_checker.py new file mode 100644 index 0000000000..7276a51fbc --- /dev/null +++ b/lmdeploy/pytorch/engine/engine_checker.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.messages import PytorchEngineConfig + +from ..check_env.adapter import AdapterChecker +from ..check_env.base import BaseChecker +from ..check_env.model import ModelChecker +from ..check_env.torch import TorchChecker +from ..check_env.transformers import TransformersChecker + + +class EngineChecker(BaseChecker): + """check transformers is available.""" + + def __init__(self, + model_path: str, + engine_config: PytorchEngineConfig, + trust_remote_code: bool = True, + logger=None): + super().__init__(logger) + logger = self.get_logger() + + self.engine_config = engine_config + + dtype = engine_config.dtype + device_type = engine_config.device_type + + # pytorch + torch_checker = TorchChecker(logger=logger) + + if device_type == 'cuda': + # triton + from ..check_env.triton import TritonChecker + triton_checker = TritonChecker(logger=logger) + triton_checker.register_required_checker(torch_checker) + self.register_required_checker(triton_checker) + else: + # deeplink + from ..check_env.deeplink import DeeplinkChecker + dl_checker = DeeplinkChecker(device_type, logger=logger) + self.register_required_checker(dl_checker) + self.register_required_checker(torch_checker) + + # transformers + + # model + trans_checker = TransformersChecker() + model_checker = ModelChecker(model_path=model_path, + trust_remote_code=trust_remote_code, + dtype=dtype, + device_type=device_type, + logger=logger) + model_checker.register_required_checker(torch_checker) + model_checker.register_required_checker(trans_checker) + self.register_required_checker(model_checker) + + # adapters + adapters = engine_config.adapters + if adapters is not None: + adapter_paths = list(adapters.values()) + for adapter in adapter_paths: + adapter_checker = AdapterChecker(adapter, logger=logger) + self.register_required_checker(adapter_checker) + + def check(self): + """check.""" + engine_config = self.engine_config + logger = self.get_logger() + + if engine_config.thread_safe: + logger.warning('thread safe mode has been deprecated and' + ' it would be removed in the future.') + + if engine_config.max_batch_size <= 0: + self.log_and_exit( + mod_name='Engine', + message='max_batch_size should be' + f' greater than 0, but got {engine_config.max_batch_size}') diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 455ab1ccb3..dff9667eb4 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List +from typing import Any, Dict, List from lmdeploy.messages import EngineOutput, GenerationConfig from lmdeploy.utils import get_logger -from ..messages import (InputEmbeddingRangeType, InputEmbeddings, - InputEmbeddingType, SamplingParam) +from ..messages import SamplingParam from .engine import Engine from .request import RequestSender, RequestType, Response, ResponseType logger = get_logger('lmdeploy') +InputMultiModalType = List[Dict[str, Any]] + def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None): """check if response has state.""" @@ -114,15 +115,13 @@ def _try_add_session(self, session_id: int): """ return try_add_session(self.req_sender, session_id) - async def async_stream_infer( - self, - session_id: int, - input_ids: List[int], - gen_config: GenerationConfig = None, - adapter_name: str = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, - **kwargs): + async def async_stream_infer(self, + session_id: int, + input_ids: List[int], + gen_config: GenerationConfig = None, + multimodal: InputMultiModalType = None, + adapter_name: str = None, + **kwargs): """Send stream inference request. Args: @@ -144,21 +143,13 @@ async def async_stream_infer( await self.req_sender.async_send_async( RequestType.ADD_SESSION, dict(session_id=session_id, response=False)) - input_embeddings_new: List[InputEmbeddings] = None - if input_embeddings is not None and len(input_embeddings) > 0: - assert len(input_embeddings) == len(input_embedding_ranges) - input_embeddings_new = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_embeddings, input_embedding_ranges) - ] - msg = dict(token_ids=input_ids, - session_id=session_id, - sampling_param=sampling_param, - adapter_name=adapter_name, - input_embeddings=input_embeddings_new, - mrope_position_ids=kwargs.get('mrope_position_ids'), - mrope_position_delta=kwargs.get('mrope_position_delta'), - cross_attention_states=kwargs.get('cross_attention_states')) + msg = dict( + token_ids=input_ids, + session_id=session_id, + sampling_param=sampling_param, + adapter_name=adapter_name, + input_multimodals=multimodal, + ) req_id = await self.req_sender.async_send_async( RequestType.ADD_MESSAGE, msg) @@ -179,14 +170,12 @@ async def async_stream_infer( yield EngineOutput(resp.type, [], 0) break - async def async_infer( - self, - session_id: int, - input_ids: List[int] = None, - gen_config: GenerationConfig = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, - **kwargs): + async def async_infer(self, + session_id: int, + input_ids: List[int] = None, + multimodal: InputMultiModalType = None, + gen_config: GenerationConfig = None, + **kwargs): """Send inference request. Args: @@ -200,13 +189,11 @@ async def async_infer( int: The number of the output tokens. """ token_ids = [] - async for outputs in self.async_stream_infer( - session_id, - input_ids, - gen_config=gen_config, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs): + async for outputs in self.async_stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + **kwargs): status, tmp_ids = outputs.status, outputs.token_ids if status not in [ResponseType.SUCCESS, ResponseType.FINISH]: return EngineOutput(status, token_ids, len(token_ids)) @@ -217,10 +204,9 @@ async def async_infer( def stream_infer(self, session_id: int, input_ids: List[int], + multimodal: InputMultiModalType = None, gen_config: GenerationConfig = None, adapter_name: str = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, **kwargs): """Send stream inference request. @@ -241,14 +227,12 @@ def stream_infer(self, def __call_async(): """call async.""" - coro_gen = self.async_stream_infer( - session_id, - input_ids, - gen_config, - adapter_name, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs) + coro_gen = self.async_stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + adapter_name=adapter_name, + **kwargs) while True: try: yield self.req_sender.run_until_complete( @@ -264,19 +248,12 @@ def __call_async(): sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) self.req_sender.send_async(RequestType.ADD_SESSION, dict(session_id=session_id, response=False)) - input_embeddings_new: List[InputEmbeddings] = None - if input_embeddings is not None and len(input_embeddings) > 0: - assert len(input_embeddings) == len(input_embedding_ranges) - input_embeddings_new = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_embeddings, input_embedding_ranges) - ] msg = dict( token_ids=input_ids, session_id=session_id, sampling_param=sampling_param, adapter_name=adapter_name, - input_embeddings=input_embeddings_new, + input_multimodals=multimodal, ) req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg) @@ -300,9 +277,8 @@ def __call_async(): def infer(self, session_id: int, input_ids: List[int] = None, + multimodal: InputMultiModalType = None, gen_config: GenerationConfig = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, **kwargs): """Send inference request. @@ -317,13 +293,11 @@ def infer(self, int: The number of the output tokens. """ token_ids = [] - for outputs in self.stream_infer( - session_id, - input_ids, - gen_config=gen_config, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs): + for outputs in self.stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + **kwargs): status, tmp_ids = outputs.status, outputs.token_ids if status not in [ResponseType.SUCCESS, ResponseType.FINISH]: return EngineOutput(status, token_ids, len(token_ids)) @@ -331,127 +305,6 @@ def infer(self, return EngineOutput(0, token_ids, len(token_ids)) - async def async_batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, - ): - """Send inference request. - - Args: - session_ids (List[int]): The session id. - token_ids (List[int]): The input token ids. - gen_config (GenerationConfig): The sampling parameters. - adapter_names (List[str]): The name of the adapters. - keep_cache (bool): Keep kv cache after infer. - - Returns: - int: Error flags. 0 if success. - List[int]: The streaming output tokens. - int: The number of the output tokens. - """ - batch_size = len(token_ids) - assert len(session_ids) == batch_size - if adapter_names is not None: - assert len(adapter_names) == batch_size - else: - adapter_names = [None for _ in range(batch_size)] - - if input_embeddings is not None: - assert len(input_embeddings) == batch_size - assert len(input_embedding_ranges) == batch_size - else: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - - async def _add_sessions(session_ids): - for session_id in session_ids: - await self._async_try_add_session(session_id) - - async def _add_messages(session_ids, token_ids, adapter_names, - input_embeddings, input_embedding_ranges): - add_msgs = [] - sampling_param = SamplingParam.from_gen_config(gen_config) - for session_id, token_id, adapter_name, input_emb, input_ranges in zip( # noqa: E501 - session_ids, token_ids, adapter_names, input_embeddings, - input_embedding_ranges): - cur_input_embeddings: List[InputEmbeddings] = None - if input_emb is not None and len(input_emb) > 0: - assert len(input_emb) == len(input_ranges) - cur_input_embeddings = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_emb, input_ranges) - ] - msg = dict( - token_ids=token_id, - session_id=session_id, - sampling_param=sampling_param, - adapter_name=adapter_name, - input_embeddings=cur_input_embeddings, - ) - add_msgs.append(msg) - req_types = [RequestType.ADD_MESSAGE] * batch_size - req_ids = await self.req_sender.async_batched_send_async( - req_types, data=add_msgs) - return req_ids - - await _add_sessions(session_ids) - req_ids = await _add_messages(session_ids, token_ids, adapter_names, - input_embeddings, input_embedding_ranges) - - # receive messages - req_idx_map = dict(zip(req_ids, range(len(req_ids)))) - output_token_ids = [list() for _ in req_ids] - status = 0 - finish_count = batch_size - while finish_count: - resp = await self.req_sender.async_recv_any() - if resp.req_id not in req_ids: - continue - idx = req_idx_map[resp.req_id] - token_ids = output_token_ids[idx] - if resp.type == ResponseType.SUCCESS: - token_ids += resp.data['token_ids'] - elif resp.type == ResponseType.FINISH: - token_ids += resp.data['token_ids'] - if not keep_cache: - session_id = session_ids[idx] - await self.async_end(session_id=session_id) - finish_count -= 1 - else: - logger.error(f'Unexpected response: {resp.type}') - status = 1 - break - - output_token_len = [len(token_ids) for token_ids in output_token_ids] - return EngineOutput(status, output_token_ids, output_token_len) - - def batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, - ): - """batched infer.""" - coro = self.async_batched_infer( - session_ids, - token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - return self.req_sender.run_until_complete(coro) - async def async_end(self, session_id: int): """End the given session.""" return await async_end(self.req_sender, session_id) @@ -470,8 +323,7 @@ def cancel(self, session_id: int): def decode(self, input_ids, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, + multimodal: List[InputMultiModalType] = None, steps: List[int] = None, sequence_start: bool = True, sequence_end: bool = True, @@ -481,10 +333,8 @@ def decode(self, Args: input_ids (numpy.ndarray): the batch of input token ids steps (List[int]): the offset of the k/v cache - input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]): - embeddings features - input_embedding_ranges: (List[List[Tuple[int, int]]]): - the begin/end offsets of input_embeddings to input_ids + multimodal (List[InputMultiModalType]): + multimodals inputs. sequence_start (bool): indicator for starting a sequence sequence_end (bool): indicator for ending a sequence adapter_names (List[str]): The name of the adapters. @@ -494,33 +344,24 @@ def decode(self, batch_size = len(input_ids) def __add_messages(session_ids, input_ids, adapter_names, - input_embeddings, input_embedding_ranges): + input_multimodals): add_msgs = [] sampling_param = SamplingParam(max_new_tokens=0) batch_size = len(input_ids) - if input_embeddings is None: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - for (session_id, token_id, adapter_name, input_emb, - input_ranges) in zip(session_ids, input_ids, adapter_names, - input_embeddings, - input_embedding_ranges): + if input_multimodals is None: + input_multimodals = [None] * batch_size + for (session_id, token_id, adapter_name, + in_mm) in zip(session_ids, input_ids, adapter_names, + input_multimodals): if len(token_id) > self.max_input_len: raise RuntimeError( f'Expect input length<={self.max_input_len} ' f'but get {len(token_id)}') - cur_input_embeddings: List[InputEmbeddings] = None - if input_emb is not None and len(input_emb) > 0: - assert len(input_emb) == len(input_ranges) - cur_input_embeddings = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_emb, input_ranges) - ] msg = dict(token_ids=token_id, session_id=session_id, sampling_param=sampling_param, adapter_name=adapter_name, - input_embeddings=cur_input_embeddings, + input_multimodals=in_mm, return_logits=True) add_msgs.append(msg) req_types = [RequestType.ADD_MESSAGE] * batch_size @@ -536,13 +377,6 @@ def __add_messages(session_ids, input_ids, adapter_names, else: adapter_names = [None] * batch_size - if input_embeddings is not None: - assert len(input_embeddings) == batch_size - assert len(input_embedding_ranges) == batch_size - else: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - session_ids = tuple(range(batch_size)) if sequence_start: for sid in session_ids: @@ -551,7 +385,7 @@ def __add_messages(session_ids, input_ids, adapter_names, self._try_add_session(sid) req_ids = __add_messages(session_ids, input_ids, adapter_names, - input_embeddings, input_embedding_ranges) + multimodal) req_idx_map = dict(zip(req_ids, range(len(req_ids)))) finish_count = batch_size diff --git a/lmdeploy/pytorch/engine/input_process.py b/lmdeploy/pytorch/engine/input_process.py new file mode 100644 index 0000000000..7f442e153b --- /dev/null +++ b/lmdeploy/pytorch/engine/input_process.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs + +TypeModelMetas = Dict[str, Any] + +InputMultiModalType = List[Dict[str, Any]] + + +@dataclass +class PreprocessInputResult: + """results of preprocess input.""" + input_ids: List[int] + input_multimodals: Optional[MultiModalInputs] = None + model_metas: Optional[TypeModelMetas] = None + + +class BaseModelInputProcessor(ABC): + """processor of model inputs.""" + + @abstractmethod + def preprocess_input(self, + input_ids: List[int], + input_mms: InputMultiModalType = None, + **kwargs) -> PreprocessInputResult: + """preprocess input.""" + raise NotImplementedError('Not implemented.') + + +class DefaultModelInputProcessor(BaseModelInputProcessor): + """default model input processor.""" + + def preprocess_input(self, + input_ids: List[int], + input_mms: MultiModalInputs = None, + **kwargs) -> PreprocessInputResult: + """preprocess input.""" + return PreprocessInputResult( + input_ids=input_ids, + input_multimodals=input_mms, + ) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 59d77f264a..5487639d29 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -135,21 +135,26 @@ def model_forward( stream = stream or torch.cuda.current_stream() with torch.cuda.stream(stream): # forward - inputs = inputs.to_device('cuda') ctx_mgr = model.ctx_mgr context = ctx_mgr.build_context( inputs=inputs, + model_config=cache_engine.model_config, world_size=world_size, kv_caches=cache_engine.gpu_cache, kv_quant_policy=cache_engine.cache_config.quant_policy, ) with ctx_mgr.context(context): + model_metas = None + model_metas = model.update_model_metas( + past_key_values=cache_engine.gpu_cache, + context=context, + ) input_dict = model.prepare_inputs_for_generation( past_key_values=cache_engine.gpu_cache, context=context, ) output = model(**input_dict) - return dict(hidden_states=output) + return dict(hidden_states=output, model_metas=model_metas) SwapMap = Dict[int, int] @@ -177,6 +182,10 @@ def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" raise NotImplementedError('Not implemented.') + def get_input_processor(self): + """get input processor.""" + raise NotImplementedError('Not implemented.') + class BaseModelAgent(AutoModelAgent): """Base model agent. @@ -267,14 +276,16 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_in_map=swap_in_map, swap_out_map=swap_out_map) await asyncio.sleep(0) - while not self.stream.query(): - await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_input_processor(self): + """get input processor..""" + return self.patched_model.get_input_processor() + @torch.inference_mode() def _tp_build_model( @@ -360,14 +371,26 @@ def _broadcast_config(cache_config): return patched_model, cache_engine, cache_config -def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream): +def _broadcast_inputs(rank: int, inputs: Any, group: dist.group, + stream: torch.cuda.Stream): """get input tensor parallel.""" # broadcast meta info if rank != 0: inputs = [None, None, None] + else: + device_inputs = inputs[0] + meta_inputs = device_inputs.to_device('meta') + inputs[0] = meta_inputs with torch.cuda.stream(stream): - dist.broadcast_object_list(inputs) + dist.broadcast_object_list(inputs, group=group) + if rank == 0: + device_inputs.broadcast() + else: + device_inputs = inputs[0].broadcast() + + inputs[0] = device_inputs + return inputs @@ -380,6 +403,7 @@ def _tp_model_loop( adapters: Dict[str, str], world_size: int, barrier: mp.Barrier, + cpu_group: dist.group, ): """Start model loops for tensor parallel model inference. @@ -405,11 +429,12 @@ def _tp_model_loop( while True: barrier.wait() inputs, swap_in_map, swap_out_map = _broadcast_inputs( - rank, None, stream) + rank, None, cpu_group, stream) cache_swapping(cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + inputs = inputs.to_device('cuda') model_forward( patched_model, @@ -441,10 +466,13 @@ def _start_tp_process(proc_id: int, try: from lmdeploy.pytorch.check_env import check_env_deeplink check_env_deeplink(device_context.device_type) + timeout = timedelta(days=35600) dist.init_process_group('nccl', rank=rank, world_size=world_size, - timeout=timedelta(days=35600)) + timeout=timeout) + cpu_group = dist.new_group(timeout=timeout, backend='gloo') + kwargs['cpu_group'] = cpu_group dist_ctx = DistContext(rank=rank, world_size=world_size) torch.cuda.set_device(rank) with get_dist_manager().context(dist_ctx), get_device_manager( @@ -614,12 +642,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig, rank = 0 try: + timeout = timedelta(days=35600) dist.init_process_group('nccl', rank=rank, world_size=world_size, - timeout=timedelta(days=35600)) + timeout=timeout) + cpu_group = dist.new_group(timeout=timeout, backend='gloo') dist_ctx = DistContext(rank=rank, world_size=world_size) self._dist_ctx = dist_ctx + self._cpu_group = cpu_group except Exception as e: from traceback import print_exc logger.error(f'Rank[{rank}] failed.') @@ -661,7 +692,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, self.mp_bar.wait() rank = 0 _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map], - self.stream) + self._cpu_group, self.stream) + cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) @@ -687,14 +719,16 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_in_map=swap_in_map, swap_out_map=swap_out_map) await asyncio.sleep(0) - while not self.stream.query(): - await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_input_processor(self): + """get input processor..""" + return self.patched_model.get_input_processor() + def _exit_handler(agent: TPModelAgent): if hasattr(agent, 'patched_model'): @@ -722,7 +756,7 @@ def build_model_agent(model_path: str, custom_module_map (str): customized nn module map """ model_config = ModelConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code, dtype=dtype) + model_path, trust_remote_code=trust_remote_code, dtype=dtype, tp=tp) model_config.custom_module_map = custom_module_map if tp == 1: model_agent = BaseModelAgent(model_path, diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 34a11ae030..3d07225e43 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -47,6 +47,17 @@ def softcapping(qk, logit_softcapping: tl.constexpr): return qk +@triton.jit +def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr): + """load kv.""" + if causal_mask: + return tl.load(ptrs, + boundary_check=boundary_check, + padding_option='zero') + else: + return tl.load(ptrs) + + @triton.jit def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask, @@ -63,11 +74,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, for start_n in range(loop_start, loop_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load(k_ptrs) + k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) qk = tl.dot(q, k) if BLOCK_DK1 != 0: - k1 = tl.load(k1_ptrs) + k1 = _load_kv(k1_ptrs, causal_mask, boundary_check=(1, )) qk += tl.dot(q1, k1) if causal_mask: @@ -117,7 +128,7 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, acc = acc * alpha[:, None] # update acc - v = tl.load(v_ptrs) + v = _load_kv(v_ptrs, causal_mask, boundary_check=(0, )) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i @@ -172,6 +183,7 @@ def _flash_prefill_fwd_kernel( kv_group_num, head_dim_k, head_dim_v, + causal: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, BLOCK_M: tl.constexpr, @@ -260,9 +272,13 @@ def _flash_prefill_fwd_kernel( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) - history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) + if causal: + history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) + loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N + else: + history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32) + loop_end = kv_seqlen // BLOCK_N * BLOCK_N - loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N acc, l_i, m_i = _prefill_fwd_inner(acc, l_i, m_i, @@ -283,7 +299,10 @@ def _flash_prefill_fwd_kernel( BLOCK_DK1=BLOCK_DK1) loop_start = loop_end - loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) + if causal: + loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) + else: + loop_end = kv_seqlen acc, l_i, m_i = _prefill_fwd_inner(acc, l_i, m_i, @@ -333,6 +352,7 @@ def flash_attention_fwd( window_size: int = None, sm_scale: float = None, logit_softcapping: float = None, + causal: bool = True, kv_layout: str = 'hsd', ): """varlen flash Attention forward. @@ -383,6 +403,7 @@ def grid(args): BLOCK_M = max(16, 8192 // BLOCK_DK) else: BLOCK_M = max(16, 16384 // BLOCK_DK) + BLOCK_M = min(128, BLOCK_M) num_warps = 4 num_stages = min(4, max(2, 1024 // BLOCK_DK)) if BLOCK_DK >= 512: @@ -416,6 +437,7 @@ def grid(args): kv_group_num=kv_group_num, head_dim_k=head_dim_k, head_dim_v=head_dim_v, + causal=causal, window_size=window_size, logit_softcapping=logit_softcapping, BLOCK_DK=BLOCK_DK, diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py index 90b135743e..3a77164046 100644 --- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py @@ -31,7 +31,7 @@ def _flatten_kv_cache( stride_vos: tl.constexpr, stride_vod: tl.constexpr, stride_boff, - OUT_SIZE: tl.constexpr, + OUT_SIZE, HEAD_DIM_K: tl.constexpr, HEAD_DIM_V: tl.constexpr, BLOCK_BS: tl.constexpr, @@ -124,7 +124,7 @@ def _flatten_kv_cache_quant( stride_vod: tl.constexpr, stride_boff, quant_policy: tl.constexpr, - OUT_SIZE: tl.constexpr, + OUT_SIZE, HEAD_DIM_K: tl.constexpr, HEAD_DIM_V: tl.constexpr, BLOCK_BS: tl.constexpr, diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py index 8f86f0019a..fe82010761 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py +++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py @@ -3,6 +3,7 @@ from .apply_rotary_pos_emb import apply_rotary_pos_emb from .awq_kernels import awq_linear from .fill_kv_cache import fill_kv_cache +from .flash_attention import flash_attention_fwd from .fused_moe import fused_moe from .linear import linear from .moe_gating_topk_softmax import moe_gating_topk_softmax @@ -16,6 +17,7 @@ 'fill_kv_cache', 'fused_moe', 'paged_attention_fwd', + 'flash_attention_fwd', 'linear', 'moe_gating_topk_softmax', 'multinomial_sampling', diff --git a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py index fb2eee9d41..63564d7ed8 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + import dlinfer.ops as ext_ops from torch import Tensor @@ -9,7 +11,16 @@ def fill_kv_cache( key_caches: Tensor, value_caches: Tensor, kv_start_indices: Tensor, + k_scales_zeros: Sequence[Optional[Tensor]], + v_scales_zeros: Sequence[Optional[Tensor]], + quant_bits: int = 0, ): """fill key/value state to cache for paged attention.""" - return ext_ops.fill_kv_cache(key_states, value_states, key_caches, - value_caches, kv_start_indices) + return ext_ops.fill_kv_cache(key_states, + value_states, + key_caches, + value_caches, + kv_start_indices, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros, + quant_bits=quant_bits) diff --git a/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py new file mode 100644 index 0000000000..1788f947ee --- /dev/null +++ b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dlinfer.ops as ext_ops +from dlinfer.utils.type_annotation import Tensor + + +def flash_attention_fwd( + query_states: Tensor, + key_states: Tensor, + value_states: Tensor, + attn_output: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None, + window_size: int = None, + sm_scale: float = None, + logit_softcapping: float = None, + causal: bool = True, +): + num_q_heads = query_states.shape[1] + num_kv_heads = value_states.shape[1] + return ext_ops.prefill_attention( + query_states, + key_states, + value_states, + q_start_loc, + q_seqlens, + max_q_seqlen, + num_q_heads, + num_kv_heads, + attn_mask=None, + softmax_scale=sm_scale, + attn_output=attn_output, + ) diff --git a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py index 72bab2d720..275ea65261 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py +++ b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py @@ -5,12 +5,13 @@ def fused_moe( hidden_states: Tensor, - top_k: int, - topk_ids: Tensor, - topk_weights: Tensor, gate_up_weights: Tensor, down_weights: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + topk: int, + renormalize: bool, ): - """ascend fused moe.""" - return ext_ops.fused_moe(hidden_states, top_k, topk_ids, topk_weights, - gate_up_weights, down_weights) + """dlinfer fused moe.""" + return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, + topk_weights, topk_ids, topk, renormalize) diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 47bcb0cfff..ded85d476d 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -19,6 +19,9 @@ def prefill_attention( block_size: int, attn_mask: Sequence[Optional[Tensor]], is_unpaged_prefill: Optional[bool], + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], ) -> Tensor: num_q_heads = query_states.shape[1] num_kv_heads = value_states.shape[1] @@ -53,11 +56,25 @@ def prefill_attention( num_kv_heads, attn_mask, attn_output=attn_output, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) -def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, - max_kv_seq_len, block_offsets, block_size): +def paged_token_attention( + q, + k_cache, + v_cache, + attn_output, + kv_seq_len, + max_kv_seq_len, + block_offsets, + block_size, + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], +): num_q_heads, q_head_dim = q.shape[1:3] num_kv_heads = k_cache.shape[-1] // q_head_dim return ext_ops.paged_decode_attention( @@ -71,6 +88,9 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, num_q_heads, num_kv_heads, attn_output=attn_output, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) @@ -91,6 +111,9 @@ def paged_attention_fwd( block_size: int, attn_mask: Sequence[Optional[Tensor]] = (), is_unpaged_prefill: Optional[bool] = None, + kv_scales: Optional[Tensor] = None, + kv_zeros: Optional[Tensor] = None, + quant_bits: Optional[int] = 0, ): if not is_decoding: return prefill_attention( @@ -108,6 +131,9 @@ def paged_attention_fwd( block_size, attn_mask, is_unpaged_prefill, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) else: return paged_token_attention( @@ -119,4 +145,7 @@ def paged_attention_fwd( max_kv_seq_len, block_offsets, block_size, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, ) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index b16a78f1f4..0aaba98c94 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -8,6 +8,7 @@ from torch import Tensor from lmdeploy.messages import GenerationConfig, LogitsProcessor +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger from .block import LogicalTokenBlocks @@ -205,10 +206,9 @@ def add_sequence( sampling_param: SamplingParam = None, adapter_name: str = None, return_logits: bool = False, - input_embeddings: List[InputEmbeddings] = None, - mrope_position_ids: Tensor = None, - mrope_position_delta: Tensor = None, - cross_attention_states: Tensor = None) -> 'SchedulerSequence': + multimodals: MultiModalInputs = None, + input_embeddings: List[InputEmbeddings] = None + ) -> 'SchedulerSequence': """Add a new message.""" if isinstance(token_ids, Tensor): token_ids = token_ids.numpy() @@ -228,10 +228,8 @@ def add_sequence( adapter_name=adapter_name, arrive_time=time.time(), history_embeddings=HistoryEmbeddings(input_embeddings), + history_multimodals=HistoryMultiModals(multimodals), return_logits=return_logits, - mrope_position_ids=mrope_position_ids, - mrope_position_delta=mrope_position_delta, - cross_attention_states=cross_attention_states, ) self.sequences[seq.seq_id] = seq if self.seq_manager is not None: @@ -361,6 +359,66 @@ def copy(self): return self.clone() +class HistoryMultiModals: + + def __init__(self, multimodals: MultiModalInputs): + if multimodals is None: + multimodals = dict() + self.multimodals = multimodals + + def get_datas(self, start=0, end=-1): + """get multimodals from prompts position [start, end).""" + outs = dict() + test_range = range(start, end) + for modal_type, modal_datas in self.multimodals.items(): + data = [] + for modal_data in modal_datas: + if (modal_data.start not in test_range + and modal_data.end not in test_range): + continue + data.append(modal_data) + if len(data) > 0: + outs[modal_type] = data + return outs + + def add_inputs(self, input_mms: MultiModalInputs): + """add new inputs.""" + for modal_type, vals in input_mms.items(): + if modal_type in self.multimodals: + self.multimodals[modal_type] += vals + else: + self.multimodals[modal_type] = vals + + def empty(self): + if len(self.multimodals) == 0: + return 0 + + return all(len(vals) == 0 for vals in self.multimodals) + + @staticmethod + def update_multimodals(input_mms: MultiModalInputs, prev_len: int): + """update multimodals.""" + for vals in input_mms.values(): + for val in vals: + val.start += prev_len + val.end += prev_len + return input_mms + + def get_encoder_len(self, start=0, end=-1): + """get lens of encoder.""" + test_range = range(start, end) + out_len = 0 + for _, modal_datas in self.multimodals.items(): + for modal_data in modal_datas: + if modal_data.encoder_len is None: + continue + if (modal_data.start not in test_range + and modal_data.end not in test_range): + continue + out_len += modal_data.encoder_len + return out_len + + @dataclass class SchedulerSequence: """Scheduler message.""" @@ -369,6 +427,8 @@ class SchedulerSequence: history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds) history_embeddings: HistoryEmbeddings = field( default_factory=HistoryEmbeddings) + history_multimodals: HistoryMultiModals = field( + default_factory=HistoryMultiModals) num_new_tokens: int = 0 sampling_param: SamplingParam = field(default_factory=SamplingParam) logical_blocks: LogicalTokenBlocks = field( @@ -382,10 +442,7 @@ class SchedulerSequence: random_offsets: int = 0 _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 - mrope_position_ids: Optional[Tensor] = None - mrope_position_delta: Optional[int] = None - cross_attention_states: Optional[Tensor] = None - history_cross_kv_seqlens: int = 0 + model_meta: Dict[str, Any] = None def __post_init__(self): """post init.""" @@ -394,6 +451,10 @@ def __post_init__(self): self._num_images: int = len(self.history_embeddings) self._num_token_ids: int = len(self.history_cache) + self._num_history_cross: int = 0 + self._num_cross: int = self.history_multimodals.get_encoder_len( + 0, self._num_token_ids) + @property def block_size(self) -> int: """block size.""" @@ -464,6 +525,16 @@ def num_all_ids(self): """num all tokens.""" return self.history_len + self._num_token_ids + @property + def num_cross(self): + """num cross.""" + return self._num_cross + + @property + def num_history_cross(self): + """num history cross.""" + return self._num_history_cross + @property def num_blocks(self): """num blocks.""" @@ -489,22 +560,22 @@ def num_all_tokens(self): def num_all_cross_tokens(self): """num of all cross tokens.""" - if self.cross_attention_states is None: - self.history_cross_kv_seqlens = 0 - else: - self.history_cross_kv_seqlens = self.cross_attention_states.shape[ - -2] - return self.history_cross_kv_seqlens + return self._num_cross + self._num_history_cross + + def get_input_multimodals(self): + """get input multimodals.""" + start = self.num_history_ids + end = self.num_all_ids + return self.history_multimodals.get_datas(start, end) def update_token_ids(self, token_ids: Tensor, + multimodals: MultiModalInputs = None, embeddings: List[InputEmbeddings] = None, - cross_attention_states: List[Tensor] = None): + model_meta: Dict[str, Any] = None): """Update token ids, old token ids will be added to history.""" - # cross attention - if cross_attention_states is not None: - self.history_cross_kv_seqlens += cross_attention_states.shape[-2] - self.cross_attention_states = cross_attention_states + old_num_history_ids = self._num_history_ids + self._num_history_ids += self._num_token_ids # update history image nums self._num_history_images += self._num_images @@ -516,6 +587,23 @@ def update_token_ids(self, self._num_images = len(new_embeddings) self.history_embeddings.append(new_embeddings) + # update multimodals + if multimodals is not None: + multimodals = HistoryMultiModals.update_multimodals( + multimodals, self.num_all_ids) + self.history_multimodals.add_inputs(multimodals) + + # cross + self._num_history_cross += self._num_cross + if multimodals is not None: + self._num_cross = self.history_multimodals.get_encoder_len( + old_num_history_ids, self._num_history_ids) + else: + self._num_cross = 0 + + if model_meta is not None: + self.model_meta = model_meta + if isinstance(token_ids, Tensor): token_ids = token_ids.numpy() elif not isinstance(token_ids, np.ndarray): @@ -539,3 +627,12 @@ def set_step(self, step: int): self._num_history_ids = step self._num_token_ids = num_all_ids - step self.num_ignored_history = min(step, self.num_ignored_history) + + self.model_meta = None + + # cross + if self.history_multimodals is not None: + self._num_history_cross = self.history_multimodals.get_encoder_len( + 0, self.num_history_ids) + self._num_cross = self.history_multimodals.get_encoder_len( + self._num_history_ids, num_all_ids) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 669625d43d..d10da8557a 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -4,47 +4,19 @@ from typing import Any, Dict, List, Literal import torch +from torch import distributed as dist from lmdeploy.pytorch.backends import get_backend +from lmdeploy.pytorch.config import ModelConfig +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor -@dataclass -class MRopeModelInputs: - """Multimodal rotary position inputs.""" - position_ids: List[torch.LongTensor] = None - deltas: List[torch.LongTensor] = None - - def get_inputs(self, history_lengths: torch.Tensor, - seq_lengths: torch.Tensor): - mrope_position_ids = [] - for (his_len, seq_len, pos_ids, - delta) in zip(history_lengths, seq_lengths, self.position_ids, - self.deltas): - assert pos_ids.dim() == 2, 'invalid mrope_position_ids' - if his_len + seq_len <= pos_ids.shape[1]: - mrope_position_ids.append(pos_ids[:, - his_len:his_len + seq_len]) - else: - mrope_position_ids.append( - torch.tensor([his_len], device=delta.device).expand(3, -1) - + delta) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) - return mrope_position_ids - - def to_device(self, device: str): - """to device.""" - out_dict = dict() - for f in fields(self): - k = f.name - v = getattr(self, k) - if isinstance(v, torch.Tensor): - v = v.to(device) - elif isinstance(v, list): - v = [x.to(device) for x in v] - out_dict[k] = v - - return MRopeModelInputs(**out_dict) +def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'): + """broadcast tensor.""" + if value.device.type == 'meta': + value = torch.empty_like(value, device=device) + dist.broadcast(value, src) + return value @dataclass @@ -56,6 +28,7 @@ class VisionModelInputs: input_embeddings: List[List[torch.Tensor]] = None input_embedding_ranges: List[torch.LongTensor] = None input_embedding_indexing: torch.BoolTensor = None + input_multimodals: List[MultiModalTensor] = None def to_device(self, device: str): """to device.""" @@ -63,12 +36,54 @@ def to_device(self, device: str): for f in fields(self): k = f.name v = getattr(self, k) + if v is None: + continue if isinstance(v, torch.Tensor): v = v.to(device) - elif k == 'input_embedding_ranges' and v is not None: + elif k == 'input_embedding_ranges': v = [e.to(device) for e in v] - elif k == 'input_embeddings' and v is not None: + elif k == 'input_embeddings': v = [[e.to(device) for e in li] for li in v] + elif k == 'input_multimodals': + new_v = [] + for mm_datas in v: + new_mm_datas = dict() + for modal_type, data in mm_datas.items(): + data = [d.to_device(device) for d in data] + new_mm_datas[modal_type] = data + new_v.append(new_mm_datas) + v = new_v + out_dict[k] = v + + return VisionModelInputs(**out_dict) + + def broadcast(self): + """broadcast inputs. + + Do `dist.broadcast_object_list(inputs.to_device('meta'))` + before broadcast tensors. + """ + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if v is None: + continue + if isinstance(v, torch.Tensor): + v = _broadcast_tensor(v) + elif k == 'input_embedding_ranges': + v = [_broadcast_tensor(e) for e in v] + elif k == 'input_embeddings': + v = [[_broadcast_tensor(e) for e in li] for li in v] + elif k == 'input_multimodals': + new_v = [] + for mm_datas in v: + new_mm_datas = dict() + for modal_type, data in mm_datas.items(): + data = [d.broadcast() for d in data] + new_mm_datas[modal_type] = data + new_v.append(new_mm_datas) + v = new_v out_dict[k] = v return VisionModelInputs(**out_dict) @@ -119,9 +134,9 @@ class ModelInputs: num_ignored_history: torch.LongTensor local_adapter_ids: torch.LongTensor = None vision_inputs: VisionModelInputs = None - mrope_inputs: MRopeModelInputs = None - cross_attention_states: torch.Tensor = None - history_cross_kv_seqlens: torch.LongTensor = None + cross_length: torch.LongTensor = None + history_cross_length: torch.LongTensor = None + model_metas: List[Dict[str, Any]] = None def update(self, input_ids: torch.LongTensor): """update input ids.""" @@ -132,44 +147,88 @@ def update(self, input_ids: torch.LongTensor): self.input_ids = input_ids return self - def split(self, split_size: int, block_size: int): + def split(self, split_size: int): """split inputs.""" assert len( self.seq_length) == 1, ('Can not perform split on batched input.') - assert split_size % block_size == 0, ( - 'split_size should be multi of block_size.') input_ids = self.input_ids if input_ids.numel() < split_size: return self - num_blocks = split_size // block_size - overlap = (self.history_lengths[0] % block_size != 0) + flatten_mms = [] + vision_inputs = self.vision_inputs + if vision_inputs is not None: + if vision_inputs.input_multimodals is not None: + input_mms = vision_inputs.input_multimodals[0] + + flatten_mms = [] + for k, mms in input_mms.items(): + mms = [(k, mm) for mm in mms] + flatten_mms += mms + + flatten_mms = sorted(flatten_mms, key=lambda mm: mm[1].start) + max_seq_len = self.seq_length[0].item() ret = [] - block_start = 0 - for i in range(0, max_seq_len, split_size): - start = i - end = min(max_seq_len, i + split_size) - block_end = block_start + num_blocks - if overlap: - block_end += 1 - - block_offsets = self.block_offsets + start = 0 + history_cross_length = self.history_cross_length + cross_length = None + if history_cross_length is not None: + cross_length = self.history_cross_length.clone() + while start < max_seq_len: + vision_inputs = None + if len(flatten_mms) > 0: + mm_start = flatten_mms[0][1].start + mm_end = flatten_mms[0][1].end + if mm_start > self.history_lengths + start: + end = min(mm_start - self.history_lengths, + start + split_size) + else: + input_mms = dict() + key, mm = flatten_mms.pop(0) + input_mms.setdefault(key, []) + input_mms[key].append(mm) + end = start + mm.end - mm.start + while len(flatten_mms) > 0: + next_mm = flatten_mms[0] + next_start = next_mm[1].start + next_end = next_mm[1].end + if next_start < mm_end: + key = next_mm[0] + input_mms.setdefault(key, []) + input_mms[key].append(next_mm[1]) + end += max(0, next_end - mm_end) + flatten_mms.pop(0) + + if cross_length is not None: + encoder_len = next_mm[1].encoder_len + if encoder_len is not None: + cross_length += encoder_len + else: + break + vision_inputs = VisionModelInputs( + input_multimodals=[input_mms], ) + else: + end = min(max_seq_len, start + split_size) + inp = ModelInputs( input_ids=self.input_ids[:, start:end], seq_length=input_ids.new_tensor([end - start]), - block_offsets=block_offsets, + block_offsets=self.block_offsets, history_lengths=self.history_lengths + start, is_decoding=self.is_decoding, num_ignored_history=self.num_ignored_history, local_adapter_ids=self.local_adapter_ids, - vision_inputs=self.vision_inputs, - mrope_inputs=self.mrope_inputs, - cross_attention_states=self.cross_attention_states, + vision_inputs=vision_inputs, + model_metas=self.model_metas, + cross_length=cross_length, + history_cross_length=history_cross_length, ) ret.append(inp) - block_start += num_blocks + history_cross_length = cross_length + + start = end return ret @@ -183,8 +242,24 @@ def to_device(self, device: str): v = v.to(device) elif isinstance(v, VisionModelInputs): v = v.to_device(device) - elif isinstance(v, MRopeModelInputs): - v = v.to_device(device) + out_dict[k] = v + + return ModelInputs(**out_dict) + + def broadcast(self): + """broadcast inputs. + + Do `dist.broadcast_object_list(inputs.to_device('meta'))` + before broadcast tensors. + """ + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor): + v = _broadcast_tensor(v) + elif isinstance(v, VisionModelInputs): + v = v.broadcast() out_dict[k] = v return ModelInputs(**out_dict) @@ -198,6 +273,7 @@ class StepContext: dataclass provide these infos and tools. """ input_ids: torch.LongTensor + model_config: ModelConfig block_offsets: torch.LongTensor position_ids: torch.LongTensor attention_mask: torch.LongTensor @@ -210,13 +286,14 @@ class StepContext: local_adapter_ids: torch.LongTensor = None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None + input_multimodals: List[MultiModalTensor] = None vision_inputs: VisionModelInputs = None - mrope_position_ids: torch.Tensor = None attn_metadata: Any = None - cross_attn_metadata: Any = None - cross_attention_states: torch.Tensor = None + cross_seqlens: torch.LongTensor = None cross_kv_seqlens: torch.LongTensor = None + cross_attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 + model_metas: List[Dict[str, Any]] = None _outputs: Dict = field(default_factory=dict) @@ -224,6 +301,7 @@ class StepContext: def new( cls, inputs: ModelInputs, + model_config: ModelConfig, world_size: int = 1, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, @@ -239,24 +317,21 @@ def new( history_seqlens = inputs.history_lengths device = q_seqlens.device + input_multimodals = None + if inputs.vision_inputs is not None: + input_multimodals = inputs.vision_inputs.input_multimodals + # for vlm input_embeddings, input_embedding_indexing = None, None if (inputs.vision_inputs is not None and inputs.vision_inputs.input_embeddings is not None): input_embeddings, input_embedding_indexing = \ inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) - # for mrope - mrope_position_ids = None - if inputs.mrope_inputs is not None: - mrope_position_ids = inputs.mrope_inputs.get_inputs( - history_seqlens, q_seqlens) # kv_seqlens - cross_attention_states = inputs.cross_attention_states if inputs.is_decoding: attention_mask = torch.ones_like(q_seqlens)[:, None] - position_ids = history_seqlens.unsqueeze(-1) - cross_attention_states = None + position_ids = history_seqlens.unsqueeze(-1).clone() else: max_q_seqlen = q_seqlens.max().item() mask_range = torch.arange(max_q_seqlen, device=device)[None, :] @@ -265,6 +340,13 @@ def new( position_ids += history_seqlens.unsqueeze(-1) q_start_loc = q_seqlens.cumsum(0) - q_seqlens + # cross + cross_seqlens = inputs.cross_length + cross_kv_seqlens = None + if inputs.cross_length is not None: + cross_kv_seqlens = (inputs.cross_length + + inputs.history_cross_length) + # position ids 1d position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None] # seq_len + history_length @@ -273,10 +355,12 @@ def new( ret = StepContext( input_ids=inputs.input_ids, + model_config=model_config, block_offsets=inputs.block_offsets, position_ids=position_ids, input_embeddings=input_embeddings, input_embedding_indexing=input_embedding_indexing, + input_multimodals=input_multimodals, attention_mask=attention_mask, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, @@ -286,10 +370,10 @@ def new( world_size=world_size, local_adapter_ids=inputs.local_adapter_ids, vision_inputs=inputs.vision_inputs, - mrope_position_ids=mrope_position_ids, - cross_attention_states=cross_attention_states, - cross_kv_seqlens=inputs.history_cross_kv_seqlens, kv_quant_policy=kv_quant_policy, + model_metas=inputs.model_metas, + cross_seqlens=cross_seqlens, + cross_kv_seqlens=cross_kv_seqlens, ) ret = get_backend().update_step_context(ret) @@ -318,6 +402,7 @@ def __init__(self): @staticmethod def build_context( inputs: ModelInputs, + model_config: ModelConfig, world_size: int = 1, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, @@ -325,6 +410,7 @@ def build_context( """build context.""" return StepContext.new( inputs, + model_config, world_size, kv_caches, kv_quant_policy, diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 8d7a21a0a6..73f64d277c 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -1,101 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn +from torch.nn import functional as F from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin LANGUAGE_TOKEN_TYPE = 0 VISION_TOKEN_TYPE = 1 -def get_vision_expert_mask(token_type_ids: torch.LongTensor): - vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) - vision_token_mask[:, :-1] = (token_type_ids[:, :-1] - == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] - == VISION_TOKEN_TYPE) - language_token_mask = ~vision_token_mask - return vision_token_mask, language_token_mask - - -def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: - tmp = x.clone() - # image boi eoi token as LANGUAGE_TOKEN_TYPE - is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) - is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & ( - tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) - is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & ( - tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) - tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE - # final position ids - y = torch.zeros_like(x, dtype=torch.long) - y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ( - (tmp[:, 1:] == VISION_TOKEN_TYPE) & - (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) - y = y.cumsum(dim=-1) - return y - - -def _get_cogvlm_position_ids(context): - """get cogvlm position_ids.""" - q_seqlens = context.q_seqlens - history_lengths = context.kv_seqlens - q_seqlens - vision_input_info = context.vision_inputs - position_id_offsets = (vision_input_info.history_image_token_lengths - - vision_input_info.history_image_nums * 3) - lang_ids = None - vis_ids = None - if context.is_decoding: - position_ids = history_lengths - position_id_offsets - else: - if vision_input_info.input_embeddings is not None and len( - vision_input_info.input_embeddings) > 0: - starts = history_lengths - vision_input_info.history_lengths - ends = starts + q_seqlens - token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) - history_position_lengths = (vision_input_info.history_lengths - - position_id_offsets) - position_ids_all = (history_position_lengths[:, None] + - build_position_ids(token_type_ids)) - position_ids = torch.cat([ - pids[s:e] - for (pids, s, e) in zip(position_ids_all, starts, ends) - ]) - vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids) - vision_token_mask = torch.cat([ - masks[s:e] - for (masks, s, e) in zip(vision_token_mask_all, starts, ends) - ]) - mask_indexing = torch.arange(vision_token_mask.shape[-1], - device=vision_token_mask.device) - vis_ids = mask_indexing[vision_token_mask] - lang_ids = mask_indexing[~vision_token_mask] - - else: - position_ids = context.attention_mask.long().cumsum(-1) - 1 - position_ids += (history_lengths - - position_id_offsets).unsqueeze(-1) - device = position_ids.device - position_ids_1d = [ - ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) - ] - position_ids = torch.cat(position_ids_1d).to(device) - - return position_ids, lang_ids, vis_ids - - class SelfAttention(torch.nn.Module): """Parallel self-attention layer abstract class. @@ -112,11 +40,10 @@ def __init__(self, self.projection_size = config.kv_channels * config.num_attention_heads self.num_attention_heads = config.num_attention_heads - self.num_kv_heads = self.num_attention_heads + self.num_kv_heads = config.num_key_value_heads self.head_size = (self.projection_size // config.num_attention_heads) - self.multi_query_attention = config.multi_query_attention - if self.multi_query_attention: - self.num_kv_heads = config.multi_query_group_num + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) self.query_key_value = build_qkv_proj( config.hidden_size, num_q_heads=self.num_attention_heads, @@ -126,7 +53,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # apply rotary self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -410,6 +337,286 @@ def forward(self, input_ids): return embeddings +class PatchEmbedding(nn.Module): + """vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + dtype=dtype, + device=device) + self.cls_embedding = nn.Parameter( + torch.empty(1, config.hidden_size, dtype=dtype, device=device)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size, + dtype=dtype, + device=device) + + def forward(self, images): + """forward.""" + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + """vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + hidden_size = config.hidden_size + num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim**-0.5 + + # packed qkv + self.query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.dense = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + q, k, v = self.query_key_value.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.dense(attn_output) + return attn_output + + +class EVA2CLIPMLP(nn.Module): + """vision MLP.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + # gate up + quantization_config = getattr(config, 'quantization_config', None) + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.activation_fn = nn.GELU() + else: + self.activation_fn = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """forward.""" + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + """vision trans layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device) + self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward(self, hidden_states): + """forward.""" + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + """vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + """forward.""" + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + """GLU.""" + + def __init__(self, + config: PretrainedConfig, + in_features: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear_proj = nn.Linear(in_features, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + self.norm1 = nn.LayerNorm(config.hidden_size, + dtype=dtype, + device=device) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, + config.ffn_hidden_size, + bias=False, + dtype=dtype, + device=device) + self.gate_proj = nn.Linear(config.hidden_size, + config.ffn_hidden_size, + bias=False, + dtype=dtype, + device=device) + self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + """vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from argparse import Namespace + vision_config = Namespace(**config.vision_config) + + self.patch_embedding = PatchEmbedding(vision_config, + dtype=dtype, + device=device) + self.transformer = EVA2CLIPTransformer(vision_config, + dtype=dtype, + device=device) + self.linear_proj = GLU(config, + in_features=config.hidden_size, + dtype=dtype, + device=device) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2, + dtype=dtype, + device=device) + self.boi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.eoi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images): + """forward.""" + x = self.patch_embedding(images) + x = self.transformer(x) + + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x + + class ChatGLMModel(nn.Module): def __init__(self, @@ -442,19 +649,32 @@ def __init__(self, dtype=dtype, device=device) + self.vision = None + if hasattr(config, 'vision_config'): + self.vision = EVA2CLIPModel(config, dtype=dtype, device=device) + def forward( self, input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, + images: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, ): """forward.""" # token embedding if inputs_embeds is None: + images_features = None + if images is not None: + images_features = self.vision(images) + images_features = images_features.flatten(0, 1)[None] inputs_embeds = self.embedding(input_ids) + if images is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + images_features) hidden_states = inputs_embeds @@ -477,7 +697,8 @@ def get_input_embeddings(self): return self.embedding -class ChatGLMForConditionalGeneration(nn.Module, CudaGraphMixin): +class ChatGLMForConditionalGeneration(nn.Module, DeployModelMixin, + CudaGraphMixin): """rewrote model of LlamaForCausalLM.""" def __init__(self, @@ -491,12 +712,16 @@ def __init__(self, # build Model self.transformer = ChatGLMModel(config, dtype=dtype, device=device) + self.input_processor = ChatGLMInputProcessor(self.config, dtype) + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + images: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): @@ -506,6 +731,8 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) return hidden_states @@ -529,8 +756,23 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - if context.vision_inputs is not None: - position_ids = _get_cogvlm_position_ids(context)[0][None] + + images = None + image_mask = None + if context.input_multimodals is not None: + images = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + images = [data for im_data in images for data in im_data] + if len(images) != 0: + image_token_id = images[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + images = torch.stack([data.data for data in images]) + else: + images = None + image_mask = None # process vision embeddings vision_embeddings = context.input_embeddings @@ -548,9 +790,92 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + model_metas = context.model_metas + if not hasattr(self.config, 'vision_config'): + return model_metas + + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_imgs = [[] for _ in model_metas] + else: + input_imgs = [] + for mm in input_multimodals: + if mm is None: + input_imgs.append([]) + else: + input_imgs.append(mm.get('image', [])) + + config = self.config + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + num_pad = vision_token_num - 3 + + batched_num_img_tokens = [] + new_model_metas = [] + for meta, imgs in zip(model_metas, input_imgs): + if meta is None: + num_img_tokens = 0 + else: + num_img_tokens = meta.get('num_img_tokens', 0) + + batched_num_img_tokens.append(num_img_tokens) + + num_img_tokens += num_pad * len(imgs) + new_model_metas.append(dict(num_img_tokens=num_img_tokens)) + + # prepare cogvlm position_ids + q_seqlens = context.q_seqlens + position_ids = context.position_ids + + if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs): + num_img_tokens = torch.tensor(batched_num_img_tokens, + device=position_ids.device) + position_ids -= num_img_tokens[None] + else: + batched_position_ids = position_ids[0].split(q_seqlens) + for pos_ids, num_img_tok, imgs in zip(batched_position_ids, + batched_num_img_tokens, + input_imgs): + pos_ids -= num_img_tok + if len(imgs) == 0: + continue + + seq_len = pos_ids.size(0) + start = pos_ids[0].cpu().item() + new_pos_ids = [] + + imgs = sorted(imgs, key=lambda img: img.start) + for img in imgs: + img_pad_pos = img.start + 1 - num_img_tok + num_pad = img.end - img.start - 2 + new_pos_ids += list(range(start, img_pad_pos)) + new_pos_ids += [img_pad_pos] * num_pad + start = img_pad_pos + 1 + num_img_tok += num_pad + + remain = seq_len - len(new_pos_ids) + new_pos_ids += list(range(start, start + remain)) + + new_pos_ids = pos_ids.new_tensor(new_pos_ids) + pos_ids[:] = new_pos_ids + + position_ids = torch.cat(batched_position_ids)[None] + context.position_ids = position_ids + + return new_model_metas + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm @@ -558,7 +883,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'transformer.vision' in name: + if '.query_key_value' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) continue + if 'rotary_pos_emb.inv_freq' in name: continue if ('rotary_pos_emb.cos_cached' in name @@ -581,3 +916,53 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class ChatGLMInputProcessor(BaseModelInputProcessor): + """input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + if hasattr(config, 'vision_config'): + vision_config = config.vision_config + self.image_size = vision_config['image_size'] + self.patch_size = vision_config['patch_size'] + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.vision_token_num = self.num_patches // 4 + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + num_pad = input_mm['image_tokens'] + image_token_id = input_mm.get('image_token_id', 0) + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 6caf10df00..c460b8e44f 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -1,20 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. +from argparse import Namespace from typing import Any, Iterable, List, Optional, Tuple import torch import torch.distributed as dist +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.distributed import get_world_rank +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin class VisionExpertAttention(nn.Module): @@ -28,8 +35,9 @@ def __init__(self, is_cogvlm2 = hasattr(config, 'num_multi_query_heads') quantization_config = getattr(config, 'quantization_config', None) num_heads = config.num_attention_heads - num_key_value_heads = getattr(config, 'num_multi_query_heads', - num_heads) + num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads) + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) self.hidden_size = hidden_size @@ -46,7 +54,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) self.language_expert_query_key_value = build_qkv_proj( hidden_size, num_q_heads=num_heads, @@ -56,7 +64,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -322,6 +330,283 @@ def forward( return outputs +class PatchEmbedding(nn.Module): + """vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + dtype=dtype, + device=device) + self.cls_embedding = nn.Parameter( + torch.empty(1, config.hidden_size, dtype=dtype, device=device)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size, + dtype=dtype, + device=device) + + def forward(self, images): + """forward.""" + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + """vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + hidden_size = config.hidden_size + num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim**-0.5 + + # packed qkv + self.query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.dense = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + q, k, v = self.query_key_value.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.dense(attn_output) + return attn_output + + +class EVA2CLIPMLP(nn.Module): + """vision MLP.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + # gate up + quantization_config = getattr(config, 'quantization_config', None) + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.activation_fn = nn.GELU() + else: + self.activation_fn = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """forward.""" + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + """vision trans layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device) + self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward(self, hidden_states): + """forward.""" + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + """vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + """forward.""" + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + """GLU.""" + + def __init__(self, + config: PretrainedConfig, + in_features: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear_proj = nn.Linear(in_features, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + self.norm1 = nn.LayerNorm(config.hidden_size, + dtype=dtype, + device=device) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False, + dtype=dtype, + device=device) + self.gate_proj = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False, + dtype=dtype, + device=device) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + """vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + vision_config = Namespace(**config.vision_config) + + self.patch_embedding = PatchEmbedding(vision_config, + dtype=dtype, + device=device) + self.transformer = EVA2CLIPTransformer(vision_config, + dtype=dtype, + device=device) + self.linear_proj = GLU(config, + in_features=vision_config.hidden_size, + dtype=dtype, + device=device) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=vision_config.hidden_size, + kernel_size=2, + stride=2, + dtype=dtype, + device=device) + self.boi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.eoi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + + def forward(self, images): + """forward.""" + x = self.patch_embedding(images) + x = self.transformer(x) + + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + return x + + class CogVLMModel(nn.Module): """model.""" @@ -353,6 +638,9 @@ def __init__(self, dtype=dtype, device=device) + # vision model + self.vision = EVA2CLIPModel(config, dtype=dtype, device=device) + # build rotary embedding emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads @@ -371,6 +659,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, + images: torch.Tensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, lang_ids: torch.LongTensor = None, vision_ids: torch.LongTensor = None, @@ -379,7 +668,12 @@ def forward( # token embedding if inputs_embeds is None: + if images is not None: + images_features = self.vision(images) + inputs_embeds = self.embed_tokens(input_ids) + if vision_ids is not None: + inputs_embeds[0, vision_ids] = images_features.flatten(0, 1) hidden_states = inputs_embeds @@ -416,85 +710,7 @@ def get_input_embeddings(self): VISION_TOKEN_TYPE = 1 -def get_vision_expert_mask(token_type_ids: torch.LongTensor): - vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) - vision_token_mask[:, :-1] = (token_type_ids[:, :-1] - == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] - == VISION_TOKEN_TYPE) - language_token_mask = ~vision_token_mask - return vision_token_mask, language_token_mask - - -def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: - tmp = x.clone() - # image boi eoi token as LANGUAGE_TOKEN_TYPE - is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) - is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & ( - tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) - is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & ( - tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) - tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE - # final position ids - y = torch.zeros_like(x, dtype=torch.long) - y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ( - (tmp[:, 1:] == VISION_TOKEN_TYPE) & - (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) - y = y.cumsum(dim=-1) - return y - - -def _get_cogvlm_position_ids(context): - """get cogvlm position_ids.""" - q_seqlens = context.q_seqlens - history_lengths = context.kv_seqlens - q_seqlens - vision_input_info = context.vision_inputs - position_id_offsets = (vision_input_info.history_image_token_lengths - - vision_input_info.history_image_nums * 3) - lang_ids = None - vis_ids = None - if context.is_decoding: - position_ids = history_lengths - position_id_offsets - else: - if vision_input_info.input_embeddings is not None and len( - vision_input_info.input_embeddings) > 0: - starts = history_lengths - vision_input_info.history_lengths - ends = starts + q_seqlens - token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) - history_position_lengths = (vision_input_info.history_lengths - - position_id_offsets) - position_ids_all = (history_position_lengths[:, None] + - build_position_ids(token_type_ids)) - position_ids = torch.cat([ - pids[s:e] - for (pids, s, e) in zip(position_ids_all, starts, ends) - ]) - vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids) - vision_token_mask = torch.cat([ - masks[s:e] - for (masks, s, e) in zip(vision_token_mask_all, starts, ends) - ]) - mask_indexing = torch.arange(vision_token_mask.shape[-1], - device=vision_token_mask.device) - vis_ids = mask_indexing[vision_token_mask] - lang_ids = mask_indexing[~vision_token_mask] - - else: - position_ids = context.attention_mask.long().cumsum(-1) - 1 - position_ids += (history_lengths - - position_id_offsets).unsqueeze(-1) - device = position_ids.device - position_ids_1d = [ - ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) - ] - position_ids = torch.cat(position_ids_1d).to(device) - - return position_ids, lang_ids, vis_ids - - -class CogVLMForCausalLM(nn.Module, CudaGraphMixin): +class CogVLMForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin): """ModelForCausalLM.""" packed_modules_mapping = { @@ -512,6 +728,8 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + # preprocessor + self.input_processor = CogVLMInputProcessor(self.config, dtype) # build model self.model = CogVLMModel(config, dtype=dtype, device=device) # build lm_head @@ -527,6 +745,7 @@ def forward( position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + images: torch.Tensor = None, inputs_embeds: torch.Tensor = None, lang_ids: torch.LongTensor = None, vision_ids: torch.LongTensor = None, @@ -538,6 +757,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, inputs_embeds=inputs_embeds, lang_ids=lang_ids, vision_ids=vision_ids, @@ -561,8 +781,36 @@ def prepare_inputs_for_generation( """prepare input.""" # get input_ids, position_ids and attention metadatas input_ids = context.input_ids - position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context) - position_ids = position_ids[None] + + # position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context) + position_ids = context.position_ids + lang_ids = None + vis_ids = None + + # vision inputs + images = None + if context.input_multimodals is not None: + images = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + images = [data for im_data in images for data in im_data] + if len(images) == 0: + images = None + + if images is not None: + image_token_id = images[0].meta['image_token_id'] + vis_mask = input_ids[0] == image_token_id + images = torch.stack([data.data for data in images]) + + # get lang_ids + vis_range = torch.arange(0, + input_ids.size(-1), + device=input_ids.device) + vis_ids = vis_range[vis_mask] + lang_ids = vis_range[~vis_mask] + attn_metadata = context.attn_metadata # process vision embeddings @@ -581,6 +829,7 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, inputs_embeds=inputs_embeds, lang_ids=lang_ids, vision_ids=vis_ids, @@ -597,8 +846,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if 'model.vision' in name: - continue if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name @@ -607,6 +854,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.tie_word_embeddings and 'lm_head.weight' in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: + if '.vision.' in name: + continue if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -620,6 +869,136 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') + elif '.query_key_value' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') else: param = params_dict[name] load_weight(param, loaded_weight) + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + model_metas = context.model_metas + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_imgs = [[] for _ in model_metas] + else: + input_imgs = [] + for mm in input_multimodals: + if mm is None: + input_imgs.append([]) + else: + input_imgs.append(mm.get('image', [])) + + config = self.config + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + num_pad = vision_token_num - 3 + + batched_num_img_tokens = [] + new_model_metas = [] + for meta, imgs in zip(model_metas, input_imgs): + if meta is None: + num_img_tokens = 0 + else: + num_img_tokens = meta.get('num_img_tokens', 0) + + batched_num_img_tokens.append(num_img_tokens) + + num_img_tokens += num_pad * len(imgs) + new_model_metas.append(dict(num_img_tokens=num_img_tokens)) + + # prepare cogvlm position_ids + q_seqlens = context.q_seqlens + position_ids = context.position_ids + + if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs): + num_img_tokens = torch.tensor(batched_num_img_tokens, + device=position_ids.device) + position_ids -= num_img_tokens[None] + else: + batched_position_ids = position_ids[0].split(q_seqlens) + for pos_ids, num_img_tok, imgs in zip(batched_position_ids, + batched_num_img_tokens, + input_imgs): + pos_ids -= num_img_tok + if len(imgs) == 0: + continue + + seq_len = pos_ids.size(0) + start = pos_ids[0].cpu().item() + new_pos_ids = [] + + imgs = sorted(imgs, key=lambda img: img.start) + for img in imgs: + img_pad_pos = img.start + 1 - num_img_tok + num_pad = img.end - img.start - 2 + new_pos_ids += list(range(start, img_pad_pos)) + new_pos_ids += [img_pad_pos] * num_pad + start = img_pad_pos + 1 + num_img_tok += num_pad + + remain = seq_len - len(new_pos_ids) + new_pos_ids += list(range(start, start + remain)) + + new_pos_ids = pos_ids.new_tensor(new_pos_ids) + pos_ids[:] = new_pos_ids + + position_ids = torch.cat(batched_position_ids)[None] + context.position_ids = position_ids + + return new_model_metas + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class CogVLMInputProcessor(BaseModelInputProcessor): + """input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + self.vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + + def preprocess_input(self, + input_ids: List[int], + input_multimodals=None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 34debae229..66f68d90e5 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -90,6 +90,9 @@ def __init__(self, self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) + num_key_value_heads = getattr(config, 'num_key_value_heads', 1) if self.q_lora_rank is None: self.q_proj = build_colwise_linear( @@ -157,10 +160,9 @@ def __init__(self, self.num_heads, config.kv_lora_rank + self.qk_rope_head_dim, scale=self.softmax_scale, - num_kv_heads=1, + num_kv_heads=num_key_value_heads, v_head_size=config.kv_lora_rank, - replicate_kv=True, - ) + num_replicate_kv_heads=num_replicate_kv_heads) self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index 8f8659dc5e..2d2edb9f49 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -31,34 +31,31 @@ def __init__(self, self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.num_kv_heads = self.num_attention_heads + self.num_kv_heads = getattr(config, 'num_kv_heads', + config.num_attention_heads) + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) self.head_size = (self.hidden_size // config.num_attention_heads) - self.multi_query_attention = config.multi_query - if self.multi_query_attention: - self.num_kv_heads = 1 self.query_key_value = build_qkv_proj( config.hidden_size, num_q_heads=self.num_attention_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, bias=config.bias, - replicate_kv=self.multi_query_attention, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # apply rotary self.apply_rotary_pos_emb = ApplyRotaryEmb() self.rotary = config.rotary # attention - self.attn_fwd = Attention( - self.num_attention_heads, - self.head_size, - num_kv_heads=self.num_kv_heads, - alibi=config.alibi, - ) + self.attn_fwd = Attention(self.num_attention_heads, + self.head_size, + num_kv_heads=self.num_kv_heads, + alibi=config.alibi) # o_proj self.dense = build_rowwise_linear(self.hidden_size, diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index ca36f15651..86be85669e 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -31,7 +31,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = config.head_dim - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -42,7 +43,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py index 99c622e4ac..fdee716b4a 100644 --- a/lmdeploy/pytorch/models/internlm.py +++ b/lmdeploy/pytorch/models/internlm.py @@ -28,7 +28,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -39,7 +40,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 6cbc2ccff3..db246331a1 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -28,7 +28,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = hidden_size // num_heads - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.wqkv = build_qkv_proj( hidden_size, @@ -39,6 +40,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, + num_replicate_kv_heads=num_replicate_kv_heads, ) # rotary embedding @@ -395,6 +397,32 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, ) + def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], + adapter_id: int): + """load lora weights.""" + + from lmdeploy.pytorch.adapter.adapter import load_lora_weights + + num_heads = self.config.num_attention_heads + num_key_value_heads = self.config.num_key_value_heads + hidden_size = self.config.hidden_size + head_dim = hidden_size // num_heads + group_size = num_heads // num_key_value_heads + + def _rearange_wqkv(weights): + for name, loaded_weight in weights: + if 'wqkv.lora_B' in name: + loaded_weight = loaded_weight.unflatten( + 0, (-1, 2 + group_size, head_dim)) + q = loaded_weight[:, :-2].flatten(0, 2) + k = loaded_weight[:, -2].flatten(0, 1) + v = loaded_weight[:, -1].flatten(0, 1) + loaded_weight = torch.cat([q, k, v], dim=0) + yield name, loaded_weight + + weights_iter = _rearange_wqkv(weights) + load_lora_weights(self, weights_iter, adapter_id) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 70dd8f2159..1059569a09 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -1,17 +1,311 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import LayerNorm, RMSNorm +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin -class InternVLChatModel(nn.Module, CudaGraphMixin): +class InternVisionEmbeddings(nn.Module): + """intern vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), ) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=dtype, + device=device) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.empty(1, + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False).reshape( + 1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, + width) + ], + dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +NORM2FN = { + 'rms_norm': RMSNorm, + 'layer_norm': LayerNorm, +} + + +class InternAttention(nn.Module): + """intern vl attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.qkv = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=config.qkv_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm( + self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device, + ) + self.k_norm = RMSNorm( + self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device, + ) + + self.scale = self.head_dim**-0.5 + + # o_proj + self.proj = build_rowwise_linear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True, + tp_align_size=self.head_dim) + + def forward(self, hidden_states): + """forward.""" + + # qkv proj + qkv_states = self.qkv(hidden_states) + q, k, v = self.qkv.split_qkv(qkv_states) + + if self.qk_normalization: + q_shape = q.shape + q = self.q_norm(q.flatten(-2, -1)).view(q_shape) + k = self.k_norm(k.flatten(-2, -1)).view(q_shape) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class InternMLP(nn.Module): + """intern vl mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.act = ACT2FN[config.hidden_act] + + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + """intern vision encoder layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = getattr(config, 'norm_type', 'rms_norm') + + self.attn = InternAttention(config, dtype=dtype, device=device) + self.mlp = InternMLP(config, dtype=dtype, device=device) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + self.ls1 = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + self.ls2 = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + + def forward( + self, + hidden_states: torch.Tensor, + ): + """forward.""" + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1 + + hidden_states = hidden_states + self.mlp( + self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2 + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """intern vision encoder.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, dtype=dtype, device=device) + for idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + ): + """forward.""" + hidden_states = inputs_embeds + for _, encoder_layer in enumerate(self.layers): + layer_outputs = encoder_layer(hidden_states, ) + hidden_states = layer_outputs + return hidden_states + + +class InternVisionModel(nn.Module): + """intern vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + + self.embeddings = InternVisionEmbeddings(config, + dtype=dtype, + device=device) + self.encoder = InternVisionEncoder(config, dtype=dtype, device=device) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + """forward.""" + assert pixel_values.dim() == 4 + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + last_hidden_state = encoder_outputs + + return last_hidden_state + + +class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin): def __init__(self, config: PretrainedConfig, @@ -21,31 +315,106 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + self.select_layer = config.select_layer + llm_config = config.llm_config + self.llm_arch_name = llm_config.architectures[0] + self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' + + vision_config = config.vision_config + if self.is_mono: + from .internvl_patch import InternVisionPatchModel + self.vision_model = InternVisionPatchModel( + vision_config, + dtype=dtype, + device=device, + ) + else: + self.vision_model = InternVisionModel(vision_config, + dtype=dtype, + device=device) + self.language_model = build_model_from_hf_config(llm_config, dtype=dtype, device=device) - self.llm_arch_name = llm_config.architectures[0] + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.llm_config.hidden_size + self.downsample_ratio = config.downsample_ratio + self.mlp1 = nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, + dtype=dtype, + device=device), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + llm_hidden_size, + dtype=dtype, + device=device), nn.GELU(), + nn.Linear(llm_hidden_size, + llm_hidden_size, + dtype=dtype, + device=device)) # for Mono-InternVL - self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' if self.is_mono: assert dtype != torch.float16, ( 'Currently Mono-InternVL does not support FP16 due to' 'numerical instability. Please use BF16 instead.') + self.input_processor = InternVLInputProcessor(self.config, dtype) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> + # N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + """extract vision feature.""" + assert self.select_layer == -1 + vit_embeds = self.vision_model(pixel_values) + if self.is_mono: + if int(vit_embeds.shape[1]**0.5)**2 != vit_embeds.shape[1]: + vit_embeds = vit_embeds[:, 1:, :] + else: + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, vision_embedding_indexing: torch.Tensor = None, text_embedding_indexing: torch.Tensor = None, **kwargs, ): + if inputs_embeds is None and pixel_values is not None: + # extract feature + vit_embeds = self.extract_feature(pixel_values) + lang_embeds = self.language_model.get_input_embeddings()(input_ids) + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) + + inputs_embeds = lang_embeds + if self.is_mono: return self.language_model.forward( input_ids=input_ids, @@ -80,11 +449,38 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - # get inputs from context vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing + vision_embedding_indexing = None + + # vision inputs + pixel_values = None + image_mask = None + if context.input_multimodals is not None: + pixel_values = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + pixel_values = [ + data for im_data in pixel_values for data in im_data + ] + if len(pixel_values) > 0: + image_token_id = pixel_values[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat([data.data for data in pixel_values]) + else: + pixel_values = None + image_mask = None + if self.is_mono and pixel_values is not None: + vision_embedding_indexing = torch.arange(input_ids.shape[1], + device=input_ids.device) + vision_embedding_indexing = vision_embedding_indexing[ + image_mask[0]] + + # get inputs from context if vision_embeddings is not None and len(vision_embeddings) > 0: + vision_embedding_indexing = context.input_embedding_indexing if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, @@ -104,6 +500,8 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, vision_embedding_indexing=vision_embedding_indexing, text_embedding_indexing=text_embedding_indexing, @@ -114,18 +512,96 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) + def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], + adapter_id: int): + """load lora weights.""" + + if hasattr(self.language_model, 'load_lora_weights'): + return self.language_model.load_lora_weights(weights, adapter_id) + else: + from lmdeploy.pytorch.adapter.adapter import load_lora_weights + + return load_lora_weights(weights, adapter_id) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" - prefix_length = len('language_model.') + lang_prefix = 'language_model.' + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name.startswith(lang_prefix): + continue + + if 'qkv' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + lang_prefix_length = len(lang_prefix) new_weights = dict() for key, val in weights: - if not key.startswith('language_model.'): + if not key.startswith(lang_prefix): continue - new_key = key[prefix_length:] + new_key = key[lang_prefix_length:] new_weights[new_key] = val self.language_model.load_weights(new_weights.items()) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class InternVLInputProcessor(BaseModelInputProcessor): + """internvl input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + vision_config = config.vision_config + self.image_size = vision_config.image_size + self.patch_size = vision_config.patch_size + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.vision_token_num = self.num_patches // 4 + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/internvl_patch.py b/lmdeploy/pytorch/models/internvl_patch.py new file mode 100644 index 0000000000..d13ad2d39b --- /dev/null +++ b/lmdeploy/pytorch/models/internvl_patch.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.configuration_utils import PretrainedConfig + + +class InternVisionEmbeddings(nn.Module): + """mono vision.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), ) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=dtype, + device=device) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.empty(1, + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False) + pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, + width) + ], + dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternVisionPatchModel(nn.Module): + """mono vision.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embeddings = InternVisionEmbeddings(config, + dtype=dtype, + device=device) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + if len(pixel_values.shape) != 4: + raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') + + hidden_states = self.embeddings(pixel_values)[:, 1:] + return hidden_states diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index f38c5ef02b..1a98c02f03 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -29,7 +29,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -40,6 +41,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, + num_replicate_kv_heads=num_replicate_kv_heads, ) # rotary embedding @@ -450,22 +452,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class LlavaLlamaForCausalLM(LlamaForCausalLM): - """llava llama for causallm.""" - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - """load weights.""" - - new_weights = dict() - for key, val in weights: - if key.startswith('model.vision_tower'): - continue - if key.startswith('model.mm_projector'): - continue - if key.startswith('model.image_newline'): - continue - new_weights[key] = val - - super().load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index 56cb5ca675..751f7343ec 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -1,17 +1,443 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.llava.configuration_llava import LlavaConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin -class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin): +class LlavaMultiModalProjector(nn.Module): + + def __init__(self, + config: LlavaConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CLIPVisionEmbeddings(nn.Module): + """clip vision embedding.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + dtype=dtype, + device=device, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding( + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device, + ) + self.register_buffer('position_ids', + torch.arange(self.num_positions, + device=device).expand((1, -1)), + persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """This method allows to interpolate the pre-trained position + encodings, to be able to use the model on higher resolution images. + + This method is also adapted to support torch.jit tracing. + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing + # to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing( + ) and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + from transformers.utils import torch_int + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, + sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode='bicubic', + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size + or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f' ({self.image_size}*{self.image_size}).') + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +class CLIPAttention(nn.Module): + """clip attention.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.qkv_proj = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + self.scale = self.head_dim**-0.5 + + # o_proj + self.out_proj = build_rowwise_linear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + q, k, v = self.qkv_proj.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if attention_mask is not None and causal_attention_mask is not None: + attn_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attn_mask = causal_attention_mask + else: + attn_mask = attention_mask + + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attn_mask, + scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.out_proj(attn_output) + return attn_output + + +class CLIPMLP(nn.Module): + """clip mlp.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + from transformers.activations import ACT2FN + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + self.fc2 = build_rowwise_linear( + config.intermediate_size, + config.hidden_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + """clip encoder layer.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config, dtype=dtype, device=device) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.mlp = CLIPMLP(config, dtype=dtype, device=device) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + ): + """forward.""" + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Module): + """clip encoder.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + CLIPEncoderLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + vision_feature_layer: int = -1, + ): + """forward.""" + hidden_states = inputs_embeds + num_vision_layers = len(self.layers) + vision_feature_layer + 1 + for _, encoder_layer in enumerate(self.layers[:num_vision_layers]): + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask=causal_attention_mask, + ) + + hidden_states = layer_outputs + + return hidden_states + + +class CLIPVisionTransformer(nn.Module): + """clip vision transformer.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config, + dtype=dtype, + device=device) + self.pre_layrnorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.encoder = CLIPEncoder(config, dtype=dtype, device=device) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + vision_feature_layer: int = -1, + ) -> BaseModelOutputWithPooling: + """forward.""" + hidden_states = self.embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + vision_feature_layer=vision_feature_layer) + + last_hidden_state = encoder_outputs + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=None, + attentions=None, + ) + + +class CLIPVisionModel(nn.Module): + """clip vision model.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.vision_model = CLIPVisionTransformer(config, + dtype=dtype, + device=device) + + def forward(self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + vision_feature_layer: int = -1, + **kwargs): + """forward.""" + return self.vision_model( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + vision_feature_layer=vision_feature_layer) + + +def build_vision_model(vision_config, + dtype: torch.dtype = None, + device: torch.device = None): + """build vision model.""" + model_type = vision_config.model_type + + if model_type == 'clip_vision_model': + return CLIPVisionModel(vision_config, dtype, device) + else: + raise NotImplementedError(f'<{model_type}> is not implemented.') + + +class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin, + DeployModelMixin): def __init__(self, config: PretrainedConfig, @@ -22,19 +448,67 @@ def __init__(self, self.config = config self.ctx_mgr = ctx_mgr text_config = config.text_config + + self.vision_tower = build_vision_model(config.vision_config, + dtype=dtype, + device=device) + self.language_model = build_model_from_hf_config(text_config, dtype=dtype, device=device) + self.multi_modal_projector = LlavaMultiModalProjector(config, + dtype=dtype, + device=device) + + self.input_processor = LLavaInputProcessor(config, dtype) + + def get_image_features(self, + pixel_values, + vision_feature_layer: int = -1, + vision_feature_select_strategy: str = 'default'): + """get image features.""" + selected_image_feature = self.vision_tower( + pixel_values, vision_feature_layer=vision_feature_layer)[0] + if vision_feature_select_strategy == 'default': + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == 'full': + selected_image_feature = selected_image_feature + else: + raise ValueError( + f'Unexpected select feature strategy: {vision_feature_select_strategy}' # noqa: E501 + ) + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features.flatten(0, 1)[None] + + return image_features + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): + if inputs_embeds is None: + image_features = None + if pixel_values is not None: + vision_feature_layer = self.config.vision_feature_layer + select_strategy = self.config.vision_feature_select_strategy + image_features = self.get_image_features( + pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=select_strategy) + inputs_embeds = self.language_model.get_input_embeddings()( + input_ids) + if pixel_values is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + image_features) + return self.language_model.forward(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, @@ -59,6 +533,27 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata + + # vision inputs + pixel_values = None + image_mask = None + if context.input_multimodals is not None: + pixel_values = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + pixel_values = [ + data for im_data in pixel_values for data in im_data + ] + if len(pixel_values) > 0: + image_token_id = pixel_values[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat([data.data for data in pixel_values]) + else: + pixel_values = None + image_mask = None + # get inputs from context vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -75,18 +570,404 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" - prefix_length = len('language_model.') + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ] + + # vis model + lang_prefix = 'language_model.' + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name.startswith(lang_prefix): + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + # language model + prefix_length = len(lang_prefix) new_weights = dict() for key, val in weights: - if not key.startswith('language_model.'): + if not key.startswith(lang_prefix): continue new_key = key[prefix_length:] new_weights[new_key] = val self.language_model.load_weights(new_weights.items()) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class LLavaInputProcessor(BaseModelInputProcessor): + """llava input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + + from transformers.image_processing_utils import select_best_resolution + + if not isinstance(grid_pinpoints, list): + raise TypeError('grid_pinpoints should be a list of tuples or lists') + + if not isinstance(image_size, (list, tuple)): + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """Unpads a PyTorch tensor of a padded and resized image.""" + if not isinstance(original_size, (list, tuple)): + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(round(original_height * scale_factor, 7)) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(round(original_width * scale_factor, 7)) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """Calculate the number of patches after the preprocessing for images of + any resolution.""" + from transformers.image_processing_utils import select_best_resolution + if not isinstance(grid_pinpoints, list): + raise TypeError('grid_pinpoints should be a list of tuples or lists') + + if not isinstance(image_size, (list, tuple)): + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + + num_patches = (height // patch_size) * (width // patch_size) + # add the base patch + num_patches += 1 + return num_patches + + +class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration): + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config=config, + ctx_mgr=ctx_mgr, + dtype=dtype, + device=device) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, + dtype=dtype, + device=device)) + self.input_processor = LLavaNextInputProcessor(config, dtype) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) for imsize in image_sizes + ] + if pixel_values.dim() == 5: + # stacked if input is + # (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] + for pix_val, num_patch in zip(pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of + # (num_patches, num_channels, height, width) + raise ValueError(f'pixel_values of shape {pixel_values.shape}, ' + 'expect to be of 4 or 5 dimensions') + + selected_image_feature = self.vision_tower( + pixel_values, vision_feature_layer=vision_feature_layer)[0] + if vision_feature_select_strategy == 'default': + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == 'full': + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + + def pack_image_features(self, + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=None): + + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = (self.config.vision_config.image_size // + self.config.vision_config.patch_size) + + if vision_feature_select_strategy == 'default': + expected_num_patches = height * width + elif vision_feature_select_strategy == 'full': + expected_num_patches = height * width + 1 + if expected_num_patches != base_image_feature.shape[0]: + raise ValueError('The number of patches is ' + 'not consistent with the image size.') + + (num_patch_height, + num_patch_width) = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, + num_patch_width, height, + width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, + 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, + image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1).to( + image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), + dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat( + (image_feature, image_newline[None].to(image_feature)), + dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + image_features = torch.cat(new_image_features, dim=0) + return image_features + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_sizes: torch.Tensor = None, + image_mask: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + if inputs_embeds is None: + image_features = None + if pixel_values is not None: + vision_feature_layer = self.config.vision_feature_layer + select_strategy = self.config.vision_feature_select_strategy + image_sizes = image_sizes.tolist() + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=select_strategy) + image_features = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=select_strategy, + image_newline=self.image_newline, + ) + image_features = image_features[None] + inputs_embeds = self.language_model.get_input_embeddings()( + input_ids) + if pixel_values is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + image_features) + + return self.language_model.forward(input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + attn_metadata=attn_metadata) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # vision inputs + pixel_values = None + image_sizes = None + image_mask = None + if context.input_multimodals is not None: + img_mms = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + img_mms = [data for im_data in img_mms for data in im_data] + if len(img_mms) > 0: + image_token_id = img_mms[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat( + [data.data.flatten(0, 1) for data in img_mms]) + image_sizes = torch.cat( + [data.meta['image_sizes'] for data in img_mms]) + else: + pixel_values = None + image_sizes = None + + # get inputs from context + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_sizes=image_sizes, + image_mask=image_mask, + inputs_embeds=inputs_embeds, + ) + + +class LLavaNextInputProcessor(BaseModelInputProcessor): + """llava input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + image_sizes = input_mm['image_sizes'] + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_sizes=image_sizes, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/minicpmv26.py b/lmdeploy/pytorch/models/minicpmv26.py index 725e97d9d7..e551dda841 100644 --- a/lmdeploy/pytorch/models/minicpmv26.py +++ b/lmdeploy/pytorch/models/minicpmv26.py @@ -29,7 +29,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -40,7 +41,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index 04af4c8526..ad27963093 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ -420,22 +420,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class LlavaMistralForCausalLM(MistralForCausalLM): - """llava forcausallm.""" - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - """load weights.""" - - new_weights = dict() - for key, val in weights: - if key.startswith('model.vision_tower'): - continue - if key.startswith('model.mm_projector'): - continue - if key.startswith('model.image_newline'): - continue - new_weights[key] = val - - super().load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/mllama.py b/lmdeploy/pytorch/models/mllama.py index 2596fe5299..15b3e9732b 100644 --- a/lmdeploy/pytorch/models/mllama.py +++ b/lmdeploy/pytorch/models/mllama.py @@ -3,23 +3,61 @@ import torch from torch import nn +from torch.nn import functional as F from transformers.models.llama import LlamaConfig -from transformers.models.mllama.modeling_mllama import MllamaTextConfig +from transformers.models.mllama.modeling_mllama import (MllamaTextConfig, + MllamaVisionConfig) +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, - SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, RMSNorm, + RopeType, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.nn.rotary_embedding import Llama3Parameters from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .utils.cudagraph import CudaGraphMixin +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2 +from .utils.model import DeployModelMixin MLLAMA_IMAGE_TOKEN_ID = 128256 MLLAMA_IMAGE_TOKEN = '<|image|>' +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, + # max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask * attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + class LlamaAttention(nn.Module): """Rewrite module of LlamaAttention.""" @@ -157,6 +195,7 @@ def __init__(self, self.head_dim, num_kv_heads=self.num_key_value_heads, v_head_size=self.head_dim, + causal=False, ) self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) @@ -579,7 +618,542 @@ def get_logits(self, hidden_states: torch.Tensor): return self.lm_head(hidden_states) -class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin): +class MllamaPrecomputedPositionEmbedding(nn.Module): + """vis position embedding.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.config = config + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + + self.gate = nn.Parameter(torch.empty(1, dtype=dtype, device=device)) + + # position embedding + self.embedding = nn.Parameter( + torch.empty(self.num_patches, + self.hidden_size, + dtype=dtype, + device=device)) + + # tile position embedding + self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * + self.num_patches * self.hidden_size, + dtype=dtype, + device=device) + + self._weight_inited = False + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + gate_tanh = self.gate.tanh() + gated_position_embedding = (1 - gate_tanh) * self.embedding + self.gate_tanh = gate_tanh + self.gated_position_embedding = gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + self._weight_inited = True + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """forward.""" + self._init_weight() + + # position embeddings + hidden_state = hidden_state + self.gated_position_embedding + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = (self.gate_tanh * + tile_position_embedding) + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: MllamaVisionConfig, + is_gated: bool = True, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size, + dtype=dtype, + device=device) + if is_gated: + self.gate = nn.Parameter(torch.empty(1, dtype=dtype, + device=device)) + + self._weight_inited = False + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + gate_tanh = self.gate.tanh() + self.gate_tanh = gate_tanh + + self._weight_inited = True + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + self._init_weight() + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate_tanh + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaVisionAttention(nn.Module): + """mllama vision attention.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + + # packed qkv + self.qkv_proj = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.o_proj = build_rowwise_linear(self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_state.size(0) + qkv_states = self.qkv_proj(hidden_state) + qkv_states = qkv_states.flatten(0, -2) + query, key, value = self.qkv_proj.split_qkv(qkv_states) + + query = query.unflatten(0, (batch_size, -1)) + key = key.unflatten(0, (batch_size, -1)) + value = value.unflatten(0, (batch_size, -1)) + q_seq_len = query.shape[1] + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(query, + key, + value, + attn_mask=attention_mask) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + return output + + +class MllamaVisionMLP(nn.Module): + """mllama vision mlp.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionEncoderLayer(nn.Module): + """vision encoder layer.""" + + def __init__(self, + config: MllamaVisionConfig, + is_gated: bool, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_size = config.hidden_size + self.is_gated = is_gated + self.self_attn = MllamaVisionAttention(config, + dtype=dtype, + device=device) + self.mlp = MllamaVisionMLP(config, dtype=dtype, device=device) + + self.input_layernorm = LayerNorm(self.hidden_size, + eps=config.norm_eps, + dtype=dtype, + device=device) + self.post_attention_layernorm = LayerNorm(self.hidden_size, + eps=config.norm_eps, + dtype=dtype, + device=device) + + if is_gated: + self.gate_attn = nn.Parameter( + torch.empty(1, dtype=dtype, device=device)) + self.gate_ffn = nn.Parameter( + torch.empty(1, dtype=dtype, device=device)) + + self._weight_inited = not is_gated + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + self.gate_attn_tanh = self.gate_attn.tanh() + self.gate_ffn_tanh = self.gate_ffn.tanh() + + self._weight_inited = True + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + self._init_weight() + + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + if self.is_gated: + hidden_state = self.gate_attn_tanh * hidden_state + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + if self.is_gated: + hidden_state = self.gate_ffn_tanh * hidden_state + hidden_state = residual + hidden_state + + outputs = hidden_state + + return outputs + + +class MllamaVisionEncoder(nn.Module): + """vision encoder.""" + + def __init__(self, + config: MllamaVisionConfig, + num_layers=32, + is_gated=False, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, + is_gated, + dtype=dtype, + device=device) for _ in range(num_layers) + ]) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + encoder_states = () + for encoder_layer in self.layers: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + ) + encoder_states = encoder_states + (hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + """vision model.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + + self.config = config + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.intermediate_layers_indices = config.intermediate_layers_indices + self.dtype = dtype + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding='valid', + bias=False, + dtype=dtype, + device=device, + ) + + self.class_embedding = nn.Parameter( + torch.empty(self.hidden_size, dtype=dtype, device=device)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config, + dtype=dtype, + device=device, + ) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501 + config, + is_gated=True, + dtype=dtype, + device=device, + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501 + config, + is_gated=True, + dtype=dtype, + device=device, + ) + + # layer norms + self.layernorm_pre = nn.LayerNorm( + self.hidden_size, + dtype=dtype, + device=device, + ) + self.layernorm_post = nn.LayerNorm( + self.hidden_size, + dtype=dtype, + device=device, + ) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + dtype=dtype, + device=device, + ) + self.global_transformer = MllamaVisionEncoder( + config, + config.num_global_layers, + is_gated=True, + dtype=dtype, + device=device, + ) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + ): + """forward.""" + (batch_size, num_concurrent_media, num_tiles, num_channels, height, + width) = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + + # Add cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode='constant', value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state = global_output[0] + + # Remove padding form hidden state + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + all_intermediate_hidden_states = [ + all_intermediate_hidden_states[i] + for i in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack( + all_intermediate_hidden_states, dim=-1) + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + + return hidden_state + + +class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin, + DeployModelMixin): """rewrote model of MllamaForConditionalGeneration.""" packed_modules_mapping = { @@ -602,16 +1176,32 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + + self.vision_model = MllamaVisionModel( + config.vision_config, + dtype=dtype, + device=device, + ) # build MllamaForCausalLM self.language_model = MllamaForCausalLM(config.text_config, dtype=dtype, device=device) + + self.multi_modal_projector = build_rowwise_linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device, + ) self.dtype = dtype - def flat_encoder_result(self, cross_attention_states: torch.Tensor, - attn_metadata: Any, input_ids: torch.LongTensor): + # preprocessor + self.input_processor = MLlamaInputProcessor(self.config, dtype) + + def flat_encoder_result(self, attn_metadata: Any, + input_ids: torch.LongTensor): # since every state share the same shape - cross_attention_states = torch.cat(cross_attention_states, 0) full_text_row_masked_out_mask = torch.ones( (attn_metadata.q_seqlens.sum(), 1), dtype=torch.bool) start_pos = 0 @@ -621,39 +1211,51 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor, full_text_row_masked_out_mask[start_pos:img_id] = False start_pos += q_seq_len full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( - cross_attention_states.device) + input_ids.device) - return cross_attention_states, full_text_row_masked_out_mask + return full_text_row_masked_out_mask def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], - cross_attention_states: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor = None, + aspect_ratio_ids: torch.Tensor = None, + aspect_ratio_mask: torch.Tensor = None, attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, cross_attn_metadata: Any = None, **kwargs, ): """model forward, return logits.""" + if cross_attn_metadata is None: full_text_row_masked_out_mask = None # FIXME basically, we want to inference # text requests and image requests separately - elif cross_attention_states is None and ( - cross_attn_metadata.kv_seqlens is None - or int(cross_attn_metadata.kv_seqlens.sum()) == 0): + elif pixel_values is None and (cross_attn_metadata.kv_seqlens is None): full_text_row_masked_out_mask = None elif cross_attn_metadata.is_decoding: - cross_attention_states = None - full_text_row_masked_out_mask = torch.ones( - (attn_metadata.q_seqlens.sum(), 1), - dtype=torch.bool, - device=input_ids.device) + full_text_row_masked_out_mask = input_ids.new_ones( + input_ids.size(-1), 1) else: - cross_attention_states, full_text_row_masked_out_mask = \ - self.flat_encoder_result(cross_attention_states, cross_attn_metadata, input_ids) # noqa + full_text_row_masked_out_mask = self.flat_encoder_result( + cross_attn_metadata, input_ids) # noqa + + cross_attention_states = None + if pixel_values is not None: + cross_attention_states = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) + _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + hidden_states = self.language_model( input_ids=input_ids, position_ids=position_ids, @@ -670,15 +1272,6 @@ def get_logits(self, hidden_states: torch.Tensor): """compute logits of the model output.""" return self.language_model.get_logits(hidden_states) - def support_cuda_graph( - self, - input_ids: torch.Tensor, - **kwargs, - ): - """support cudagraph.""" - - return False - def get_input_embeddings(self): """get input embeddings.""" return self.language_model.model.get_input_embeddings() @@ -694,14 +1287,35 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - cross_attention_states = context.cross_attention_states - if cross_attention_states is not None: - cross_attention_states = [ - t.to(input_ids.device) for t in cross_attention_states - if t is not None - ] cross_attn_metadata = context.cross_attn_metadata + # cross_attn_metadata is None when inputs without image + if cross_attn_metadata is not None and int( + cross_attn_metadata.kv_seqlens.sum()) == 0: + cross_attn_metadata.kv_seqlens = None + + device = input_ids.device + + # process image input + pixel_values = None + aspect_ratio_ids = None + aspect_ratio_mask = None + if context.input_multimodals is not None: + pixel_values = [] + aspect_ratio_ids = [] + aspect_ratio_mask = [] + batched_image_data = [ + input_mm['image'] for input_mm in context.input_multimodals + ] + for image_data in batched_image_data: + for data in image_data: + pixel_values.append(data.data) + aspect_ratio_ids.append(data.meta['aspect_ratio_ids']) + aspect_ratio_mask.append(data.meta['aspect_ratio_mask']) + pixel_values = torch.cat(pixel_values, dim=0).to(device) + aspect_ratio_ids = torch.cat(aspect_ratio_ids, dim=0).to(device) + aspect_ratio_mask = torch.cat(aspect_ratio_mask, dim=0).to(device) + # process vision embeddings vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -719,7 +1333,9 @@ def prepare_inputs_for_generation( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - cross_attention_states=cross_attention_states, + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, cross_attn_metadata=cross_attn_metadata, ) @@ -742,8 +1358,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): continue - if 'vision_model' in name or 'multi_modal_projector' in name: - continue if self.config.text_config.tie_word_embeddings and 'lm_head.weight' in name: # noqa continue for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -756,3 +1370,161 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + + def support_cuda_graph( + self, + input_ids: torch.Tensor, + attn_metadata: Any, + cross_attn_metadata: Any, + **kwargs, + ): + """support cudagraph.""" + + if not attn_metadata.is_decoding: + return False + + if cross_attn_metadata is None: + return False + + if cross_attn_metadata.kv_seqlens is None: + return False + + return True + + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """make cudagraph buffers from forward inputs.""" + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, + **kwargs) + + device = graph_meta.device + max_batches = graph_meta.max_batchs + input_buffers['cross_kv_seqlens'] = torch.zeros(max_batches, + dtype=torch.int64, + device=device) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """fill cudagraph buffers from forward inputs.""" + input_buffers = graph_meta.input_buffers + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, + **kwargs) + + attn_metadata = new_inputs['attn_metadata'] + cross_attn_metadata = new_inputs['cross_attn_metadata'] + block_offsets = attn_metadata.block_offsets + batch_size, _ = block_offsets.size() + + kv_seqlens = cross_attn_metadata.kv_seqlens + if kv_seqlens.data_ptr() != input_buffers['cross_kv_seqlens'].data_ptr( + ): + input_buffers['cross_kv_seqlens'].zero_() + input_buffers['cross_kv_seqlens'][:batch_size] = kv_seqlens + + new_batch_size = next_power_of_2(batch_size) + cross_attn_metadata.block_offsets = input_buffers[ + 'block_offsets'][:new_batch_size] + cross_attn_metadata.q_start_loc = input_buffers[ + 'q_start_loc'][:new_batch_size] + cross_attn_metadata.q_seqlens = input_buffers[ + 'q_seqlens'][:new_batch_size] + cross_attn_metadata.kv_seqlens = input_buffers[ + 'cross_kv_seqlens'][:new_batch_size] + + new_inputs['cross_attn_metadata'] = cross_attn_metadata + return new_inputs + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + model_metas = context.model_metas + if model_metas is None: + batch_size = context.q_seqlens.size(0) + model_metas = [dict(cross_kv_len=0) for _ in range(batch_size)] + + if context.is_decoding: + return model_metas + + vision_inputs = context.vision_inputs + if vision_inputs is None: + return model_metas + + input_mms = vision_inputs.input_multimodals + if input_mms is None: + return model_metas + + config = self.config.vision_config + image_size = config.image_size + patch_size = config.patch_size + wh = image_size // patch_size + img_kv_len = wh * wh + 1 + img_kv_len = img_kv_len * 4 + + new_model_metas = [] + for idx, input_mm in enumerate(input_mms): + if input_mm is None: + new_model_metas.append(model_metas[idx]) + images = input_mm['image'] + num_img = len(images) + + cross_kv_len = 0 + if model_metas[idx] is not None: + cross_kv_len = model_metas[idx].get('cross_kv_len', + cross_kv_len) + cross_kv_len += img_kv_len * num_img + new_model_metas.append(dict(cross_kv_len=cross_kv_len)) + + return model_metas + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class MLlamaInputProcessor(BaseModelInputProcessor): + """mllama input processor.""" + + def __init__(self, config: LlamaConfig, dtype: torch.dtype) -> None: + self.config = config + self.dtype = dtype + + vision_config = self.config.vision_config + image_size = vision_config.image_size + patch_size = vision_config.patch_size + wh = image_size // patch_size + encoder_len = wh * wh + 1 + encoder_len = encoder_len * 4 + self.encoder_len = encoder_len + + def preprocess_input(self, input_ids, input_multimodals, **kwargs): + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + aspect_ratio_ids = input_mm['aspect_ratio_ids'] + aspect_ratio_mask = input_mm['aspect_ratio_mask'] + offset = input_mm['offset'] + + if pixel_values.dtype != self.dtype: + pixel_values = pixel_values.to(self.dtype) + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + 1, + encoder_len=self.encoder_len, + meta=dict(aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 1059bfee4e..e7b460026a 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -85,14 +85,10 @@ # llava MODULE_MAP.update( { - 'LlavaLlamaForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM', - 'LlavaMistralForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM', 'LlavaForConditionalGeneration': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501 'LlavaNextForConditionalGeneration': # noqa: E501 - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration' + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaNextForConditionalGeneration' # noqa: E501 }) # qwen @@ -158,7 +154,7 @@ # phi3 vision MODULE_MAP.update({ 'Phi3VForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3VForCausalLM', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_v.Phi3VForCausalLM', }) # phi-3.5-moe diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 9da1b9f4ea..9604b19af5 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -8,6 +8,7 @@ import torch from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import load_state_dict from lmdeploy.utils import get_logger @@ -250,6 +251,10 @@ def add_adapters(model: torch.nn.Module, ranks, scalings = get_ranks_and_scalings(target_name, adapter_cfgs, device=device) + # split in case target_name has '.' like 'attention.wo' + # which cannot be used as name of a module + # and it's not aligned with key in model.packed_modules_mapping + target_name = target_name.split('.')[-1] found_mods, pack_idx = find_all_target(model, target_name) sum_rank = ranks.sum().item() @@ -295,7 +300,9 @@ def add_adapters(model: torch.nn.Module, for name, path in adapters.items(): adapter_id = adapter_id_map[name] checkpoint_path = f'{path}/adapter_model.bin' - state_dict = torch.load(checkpoint_path, map_location=device) + if not osp.exists(checkpoint_path): + checkpoint_path = f'{path}/adapter_model.safetensors' + state_dict = load_state_dict(checkpoint_path, map_location=device) if hasattr(model, 'load_lora_weights'): model.load_lora_weights(state_dict.items(), adapter_id=adapter_id) diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index f9477fdab8..288fdf3b19 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -435,7 +435,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class Phi3VForCausalLM(Phi3ForCausalLM): - ... diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py new file mode 100644 index 0000000000..c4bf72c767 --- /dev/null +++ b/lmdeploy/pytorch/models/phi3_v.py @@ -0,0 +1,476 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig + +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn.linear import build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .phi3 import Phi3ForCausalLM, Phi3Model +from .utils.model import DeployModelMixin + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0, + dropout=0.0, + hidden_act='quick_gelu', + hidden_size=1024, + image_size=336, + initializer_factor=1.0, + initializer_range=0.02, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768) + + +class Phi3ImageEmbedding(nn.Module): + """image embedding.""" + + def __init__(self, + config: PretrainedConfig, + wte=None, + dtype: torch.dtype = None, + device: torch.device = None, + **kwargs): + super().__init__() + self.config = config + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + + self.wte = wte + + if (isinstance(config.img_processor, dict) and + config.img_processor.get('name', None) == 'clip_vision_model'): + assert 'model_name' in config.img_processor, ( + 'model_name must be provided for CLIPVisionModel') + assert 'image_dim_out' in config.img_processor, ( + 'image_dim_out must be provided for CLIPVisionModel') + assert 'num_img_tokens' in config.img_processor, ( + 'num_img_tokens must be provided for CLIPVisionModel') + assert config.img_processor[ + 'model_name'] == 'openai/clip-vit-large-patch14-336' + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config).to(device).to( + dtype) + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + else: + raise NotImplementedError( + f'img_processor = {config.img_processor}, not implemented') + + self.image_dim_out = image_dim_out + self.img_sizes = None + + self.use_hd_transform = kwargs.get('use_hd_transform', False) + self.with_learnable_separator = kwargs.get('with_learnable_separator', + False) + self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert (self.use_hd_transform == self.with_learnable_separator), ( + 'use_hd_transform and with_learnable_separator ' + 'should have same value') + if self.with_learnable_separator: + assert self.use_hd_transform, ( + 'learnable separator is only for hd transform') + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter( + torch.empty([1, 1, self.image_dim_out * 4], + dtype=dtype, + device=device)) + self.sub_GN = nn.Parameter( + torch.empty([1, 1, 1, self.image_dim_out * 4], + dtype=dtype, + device=device)) + + projection_cls = kwargs.get('projection_cls', 'linear') + if projection_cls == 'linear': + self.img_projection = nn.Linear(image_dim_out, + hidden_size, + dtype=dtype, + device=device) + elif projection_cls == 'mlp' and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out * 4, + dim_projection, + dtype=dtype, + device=device) + ] + for _ in range(1, depth): + layers.extend([ + nn.GELU(), + nn.Linear(dim_projection, + dim_projection, + dtype=dtype, + device=device) + ]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == 'mlp': + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out, + dim_projection, + dtype=dtype, + device=device) + ] + for _ in range(1, depth): + layers.extend([ + nn.GELU(), + nn.Linear(dim_projection, + dim_projection, + dtype=dtype, + device=device) + ]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError( + f'projection_cls = {projection_cls}, not implemented') + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get( + 'type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + def get_img_features(self, + img_embeds: torch.FloatTensor) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, + output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == 'patch': + patch_feature = img_feature[:, 1:] + return patch_feature + + if TYPE_FEATURE == 'cls_patch': + return img_feature + + raise NotImplementedError + + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + image_sizes=None, + image_mask: torch.Tensor = None, + ) -> torch.FloatTensor: + """forward.""" + + target_device = pixel_values.device + target_dtype = pixel_values.dtype + + img_embeds = pixel_values + img_sizes = image_sizes + img_sizes = img_sizes.cpu() + + if self.use_hd_transform and img_sizes is not None and len(img_sizes): + assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501 + # img_embeds: (num_images, max_num_crops, 3, H, W) + # img_sizes: (num_images, 2).view(1, -1) + + bs = img_embeds.shape[0] + # Nx(HW)xC + img_features = self.get_img_features(img_embeds.flatten(0, 1)) + base_feat_height = base_feat_width = int( + img_features.shape[1]**0.5) + + assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501 + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view( + bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + h = h // 336 + w = w // 336 + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature.reshape( + 1, H // 2, 2, H // 2, 2, + C).permute(0, 1, 3, 2, 4, + 5).reshape(1, H // 2, H // 2, 4 * C) + temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], + dim=2).reshape(1, -1, 4 * C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) + # ->(num_crops, 12, 12, 2, 2, 1024) + # -> (num_crops, 12*12, 4*1024) + sub_img = (sub_img.reshape(B_, H // 2, 2, H // 2, 2, + C).permute(0, 1, 3, 2, 4, 5)) + sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( + 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) + temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) + sub_img = torch.cat([sub_img, temp_sub_GN], + dim=2).reshape(1, -1, 4 * C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError( + f'hd_transform_order = {self.hd_transform_order}' + ) # noqa E501 + + temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + assert temp_len == output_imgs[-1].shape[ + 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501 + output_len.append(temp_len) + + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection( + _output_img.to(target_device).to(target_dtype)) + img_feature_proj = img_feature_proj.flatten(0, 1) + img_set_tensor.append(img_feature_proj) + img_set_tensor = torch.cat(img_set_tensor)[None] + elif img_embeds.ndim == 4: + tt = (self.get_img_features(img_embeds).to(target_device).to( + target_dtype).reshape(-1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. + elif img_embeds.ndim == 3: + tt = (img_embeds.to(target_device).to(target_dtype).view( + -1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. + else: + raise NotImplementedError + + hidden_states = self.wte(input_ids) + + hidden_states.masked_scatter_(image_mask[..., None], img_set_tensor) + + return hidden_states + + +class Phi3VModel(Phi3Model): + """phi3v model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config=config, dtype=dtype, device=device) + + self.vision_embed_tokens = None + if isinstance(config.embd_layer, dict): + # vision embedding layer + embedding_config = { + 'embedding_cls': config.embd_layer['embedding_cls'], + **config.embd_layer + } + self.vision_embed_tokens = Phi3ImageEmbedding( + config, + wte=self.embed_tokens, + dtype=dtype, + device=device, + **embedding_config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_mask: torch.Tensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + if inputs_embeds is None and pixel_values is not None: + inputs_embeds = self.vision_embed_tokens( + input_ids, + pixel_values, + image_sizes, + image_mask, + ) + + return super().forward( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + +class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin): + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config, ctx_mgr, dtype=dtype, device=device) + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Phi3VModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + self.input_processor = Phi3VInputProcessor(config, dtype) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_sizes: torch.Tensor = None, + image_mask: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """forward.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_sizes=image_sizes, + image_mask=image_mask, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + output = super().prepare_inputs_for_generation( + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + context=context) + + # vision inputs + pixel_values = None + if context.input_multimodals is not None: + input_mms = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + input_mms = [data for im_data in input_mms for data in im_data] + if len(input_mms) > 0: + pixel_values = torch.cat([data.data for data in input_mms]) + image_sizes = torch.cat( + [data.meta['image_sizes'] for data in input_mms]) + image_token_id = input_mms[0].meta['image_token_id'] + image_mask = output['input_ids'] == image_token_id + output['pixel_values'] = pixel_values + output['image_sizes'] = image_sizes + output['image_mask'] = image_mask + + return output + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + super().load_weights(weights) + + vis_prefix = 'vision_embed_tokens.' + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if not (vis_prefix in name): + continue + param = params_dict[name] + load_weight(param, loaded_weight) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class Phi3VInputProcessor(BaseModelInputProcessor): + """Phi3V input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + image_sizes = input_mm['image_sizes'] + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_sizes=image_sizes, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 82be75e167..38773c21e1 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -29,7 +29,8 @@ def __init__(self, num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - + num_replicate_kv_heads = getattr(config, + 'num_replicate_key_value_heads', 1) # packed qkv self.qkv_proj = build_qkv_proj( hidden_size, @@ -40,7 +41,7 @@ def __init__(self, quant_config=quantization_config, dtype=dtype, device=device, - ) + num_replicate_kv_heads=num_replicate_kv_heads) # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index b10baaa4d5..4e2b1017b5 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -1,18 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Callable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, - SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention, + LayerNorm, RMSNorm, RopeType, SiluAndMul, + build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2 +from .utils.model import DeployModelMixin def _apply_mrope_selection(hidden_states: torch.Tensor, @@ -337,7 +343,337 @@ def get_input_embeddings(self): return self.embed_tokens -class Qwen2VLForConditionalGeneration(nn.Module, CudaGraphMixin): +class PatchEmbed(nn.Module): + """Patch Embed.""" + + def __init__(self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view(-1, self.in_channels, + self.temporal_patch_size, + self.patch_size, self.patch_size) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim) + return hidden_states + + +class VisionRotaryEmbedding(nn.Module): + """vision rotary embedding.""" + + def __init__(self, + dim: int, + theta: float = 10000.0, + device: torch.device = None) -> None: + super().__init__() + inv_freq = 1.0 / (theta**( + torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class VisionAttention(nn.Module): + """Vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + dim = config.embed_dim + num_heads = config.num_heads + head_dim = dim // num_heads + self.head_dim = head_dim + + # packed qkv + self.qkv = build_qkv_proj( + dim, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attention = FlashAttention( + num_heads, + head_dim, + causal=False, + ) + + # o_proj + self.proj = build_rowwise_linear(dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor] + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + # qkv proj + qkv_states = self.qkv(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + q, k, v = self.qkv.split_qkv(qkv_states) + + cos, sin = rotary_pos_emb + q, k = self.apply_rotary_pos_emb(q, k, cos, sin) + + attn_output = self.attention( + q, + k, + v, + q_start_loc=cu_seqlens[:-1], + q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1], + ) + + attn_output = attn_output.reshape(seq_length, -1) + + # o proj + attn_output = self.proj(attn_output) + return attn_output + + +class VisionMlp(nn.Module): + """Vision mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + dim = config.embed_dim + hidden_dim = int(config.embed_dim * config.mlp_ratio) + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.fc1 = build_colwise_linear( + dim, + hidden_dim, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.act = nn.GELU() + else: + self.act = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(hidden_dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + return self.fc2(self.act(self.fc1(x))) + + +class Qwen2VLVisionBlock(nn.Module): + """Vision block.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + self.norm1 = LayerNorm(config.embed_dim, + eps=1e-6, + dtype=dtype, + device=device) + self.norm2 = LayerNorm(config.embed_dim, + eps=1e-6, + dtype=dtype, + device=device) + + self.attn = VisionAttention(config, dtype=dtype, device=device) + + self.mlp = VisionMlp(config, dtype=dtype, device=device) + + def forward(self, + hidden_states, + cu_seqlens, + rotary_pos_emb, + residual: Optional[torch.Tensor] = None) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + else: + hidden_states, residual = self.norm1(hidden_states, residual) + + hidden_states = self.attn(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb) + + hidden_states, residual = self.norm2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class PatchMerger(nn.Module): + """PatchMerger.""" + + def __init__(self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = nn.LayerNorm(context_dim, + eps=1e-6, + dtype=dtype, + device=device) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, + self.hidden_size, + dtype=dtype, + device=device), + nn.GELU(), + nn.Linear(self.hidden_size, dim, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen2VisionTransformerPretrainedModel(nn.Module): + """Vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + dtype=dtype, + device=device, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2, + device=device) + + self.blocks = nn.ModuleList([ + Qwen2VLVisionBlock(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.depth) + ]) + self.merger = PatchMerger(dim=config.hidden_size, + context_dim=config.embed_dim, + spatial_merge_size=config.spatial_merge_size, + dtype=dtype, + device=device) + + def rot_pos_emb(self, grid_thw): + """rotary position embedding.""" + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.patch_embed(hidden_states) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + residual = None + for blk in self.blocks: + hidden_states, residual = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + residual=residual) + + hidden_states = hidden_states + residual + + return self.merger(hidden_states) + + +class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin, + CudaGraphMixin): """ModelForCausalLM.""" packed_modules_mapping = { @@ -360,6 +696,16 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + + # preprocessor + self.input_processor = Qwen2VLInputProcessor(self.config) + + # build vision model + self.visual = Qwen2VisionTransformerPretrainedModel( + config.vision_config, + dtype=dtype, + device=device, + ) # build model self.model = Qwen2Model(config, dtype=dtype, device=device) # build lm_head @@ -377,9 +723,26 @@ def forward( attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, mrope_position_ids: torch.Tensor = None, + pixel_values: torch.Tensor = None, + vis_cu_seqlens: torch.Tensor = None, + vis_pos_emb: torch.Tensor = None, + image_mask: torch.Tensor = None, **kwargs, ): """model forward, return logits.""" + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if pixel_values is not None: + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), + vis_pos_emb[1].to(dtype)) + image_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb) + inputs_embeds = inputs_embeds.masked_scatter( + image_mask[..., None], image_embeds) + hidden_states = self.model( input_ids=input_ids, position_ids=position_ids, @@ -416,6 +779,36 @@ def prepare_inputs_for_generation( position_ids = context.position_ids attn_metadata = context.attn_metadata + pixel_values = None + vis_cu_seqlens = None + vis_pos_emb = None + image_mask = None + if context.input_multimodals is not None: + image_data = [ + input_mm['image'] for input_mm in context.input_multimodals + ] + + if len(image_data) > 0: + # flatten batch + image_data = [ + data for im_data in image_data for data in im_data + ] + pixel_values = torch.cat([data.data for data in image_data]) + image_token_id = image_data[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + grid_thw = torch.cat( + [data.meta['grid_thw'] for data in image_data]).cpu() + vis_pos_emb = self.visual.rot_pos_emb(grid_thw) + vis_cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).to(pixel_values.device) + vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, + dtype=torch.int32) + vis_pos_emb = vis_pos_emb.repeat(1, 2) + vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin()) + + mrope_position_ids = getattr(context, 'mrope_position_ids', None) + # process vision embeddings vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -433,7 +826,11 @@ def prepare_inputs_for_generation( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - mrope_position_ids=context.mrope_position_ids, + mrope_position_ids=mrope_position_ids, + pixel_values=pixel_values, + vis_cu_seqlens=vis_cu_seqlens, + vis_pos_emb=vis_pos_emb, + image_mask=image_mask, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -450,8 +847,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if 'visual' in name: - continue if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name @@ -467,8 +862,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): load_weight(param, loaded_weight, shard_id=shard_id) break else: - param = params_dict[name] - load_weight(param, loaded_weight) + if '.qkv.' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): """make cudagraph buffers from forward inputs.""" @@ -510,3 +912,130 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): 'mrope_position_ids'] return new_inputs + + def _update_model_meta_decoding(self, context: StepContext): + """update model meta for decoding.""" + model_metas = context.model_metas + position_ids = context.position_ids + + mrope_deltas = [meta['mrope_delta'] for meta in model_metas] + mrope_deltas = position_ids.new_tensor(mrope_deltas) + mrope_position_ids = position_ids + mrope_deltas[None] + mrope_position_ids = mrope_position_ids.expand(3, -1) + + context.mrope_position_ids = mrope_position_ids + return model_metas + + def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): + """get mrope ids.""" + t, h, w = grid_thw + h //= 2 + w //= 2 + stride = torch.tensor([h * w, w, 1], device=device)[:, None] + size = torch.tensor([t, h, w], device=device)[:, None] + pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) + pos_ids = pos_ids // stride % size + return pos_ids + + def _update_model_meta_prefilling(self, context: StepContext): + """update model meta for prefilling.""" + model_metas = context.model_metas + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_multimodals = [None] * len(model_metas) + position_ids = context.position_ids + batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) + mrope_position_ids = [] + new_model_metas = [] + for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, + input_multimodals): + images = [] + if input_mm is not None: + images = input_mm['image'] + if model_meta is None or 'mrope_delta' not in model_meta: + mrope_delta = 0 + else: + mrope_delta = model_meta['mrope_delta'] + + pos_start = pos_ids[0].item() + mrope_pos_ids = pos_ids + mrope_delta + mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() + for img in images: + grid_thw = img.meta['grid_thw'][0].tolist() + _, h, w = grid_thw + h //= 2 + w //= 2 + num_pad = img.end - img.start - max(h, w) + mrope_delta -= num_pad + fill_start = img.start - pos_start + fill_end = img.end - pos_start + img_pos_ids = self._get_multimodal_pos_ids( + grid_thw, pos_ids.device) + img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] + mrope_pos_ids[:, fill_end:] -= num_pad + mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids + + mrope_position_ids.append(mrope_pos_ids) + new_model_metas.append(dict(mrope_delta=mrope_delta)) + + mrope_position_ids = torch.cat(mrope_position_ids, dim=1) + context.mrope_position_ids = mrope_position_ids + + return new_model_metas + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + if context.is_decoding: + return self._update_model_meta_decoding(context) + else: + return self._update_model_meta_prefilling(context) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +InputMultiModalType = List[Dict[str, Any]] + + +class Qwen2VLInputProcessor(BaseModelInputProcessor): + """qwen2 input processor.""" + + def __init__(self, config: PretrainedConfig) -> None: + self.config = config + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + image_grid_thw = input_mm['image_grid_thw'] + offset = input_mm['offset'] + start = offset + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=start, + end=start + num_pad, + meta=dict( + grid_thw=image_grid_thw, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py new file mode 100644 index 0000000000..99bd4c4bfb --- /dev/null +++ b/lmdeploy/pytorch/models/utils/model.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterable, List, Optional, Tuple + +import torch + +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor +from lmdeploy.pytorch.model_inputs import StepContext + + +class DeployModelMixin: + + def forward(self, *args, **kwargs): + """forward of model.""" + raise NotImplementedError('Not Implemented') + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + raise NotImplementedError('Not Implemented') + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + raise NotImplementedError('Not Implemented') + + def get_logits(self, hidden_states: torch.Tensor): + """compute logits of the model output.""" + return hidden_states + + def update_weights(self): + """update weights.""" + pass + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + return None + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return None diff --git a/lmdeploy/pytorch/models/utils/multimodal.py b/lmdeploy/pytorch/models/utils/multimodal.py new file mode 100644 index 0000000000..aebcaf4073 --- /dev/null +++ b/lmdeploy/pytorch/models/utils/multimodal.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs + +PreparedInputs = Tuple[List[int], MultiModalInputs] + + +class MultiModalMixin: + + def prepare_multimodal_input(self, input_ids, input_multimodals, + **kwargs) -> PreparedInputs: + """prepare multimodals inputs.""" + raise NotImplementedError('prepare input not implemented.') diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py new file mode 100644 index 0000000000..c3e8c6a16f --- /dev/null +++ b/lmdeploy/pytorch/multimodal/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .data_type import MultiModalData, MultiModalTensor + +__all__ = ['MultiModalData', 'MultiModalTensor'] diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py new file mode 100644 index 0000000000..886c7ffbd0 --- /dev/null +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass, fields +from typing import Any, Dict, List, Union + +import torch +from torch import Tensor +from torch import distributed as dist + + +class MultiModalData: + pass + + +MultiModalDataList = List[MultiModalData] + +NestedTensor = Union[Tensor, List[Tensor]] + + +def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'): + """broadcast tensor.""" + if value.device.type == 'meta': + value = torch.empty_like(value, device=device) + dist.broadcast(value, src) + return value + + +@dataclass +class MultiModalTensor: + data: NestedTensor + start: int + end: int = None + encoder_len: int = None + meta: Dict[str, Any] = None + + def __post_init__(self): + if self.end is None: + self.end = self.start + + def to_device(self, device: str, non_blocking: bool = False): + """to device.""" + out_dict = dict() + for f in fields(self): + k = f.name + if k in ('data', 'meta'): + continue + v = getattr(self, k) + out_dict[k] = v + + if isinstance(self.data, Tensor): + data = self.data.to(device=device, non_blocking=non_blocking) + else: + data = [ + d.to(device=device, non_blocking=non_blocking) + for d in self.data + ] + out_dict['data'] = data + + new_meta = None + if self.meta is not None: + new_meta = dict() + for k, v in self.meta.items(): + if isinstance(v, Tensor): + v = v.to(device=device, non_blocking=non_blocking) + elif hasattr(v, 'to_device'): + v = v.to_device(device=device, non_blocking=non_blocking) + new_meta[k] = v + + out_dict['meta'] = new_meta + return MultiModalTensor(**out_dict) + + def broadcast(self): + """broadcast inputs tensors.""" + out_dict = dict() + for f in fields(self): + k = f.name + if k in ('data', 'meta'): + continue + v = getattr(self, k) + out_dict[k] = v + + if isinstance(self.data, Tensor): + data = _broadcast_tensor(self.data) + else: + data = [_broadcast_tensor(d) for d in self.data] + out_dict['data'] = data + + new_meta = None + if self.meta is not None: + new_meta = dict() + for k, v in self.meta.items(): + if isinstance(v, Tensor): + v = _broadcast_tensor(v) + self.meta[k] = v + elif hasattr(v, 'to_device'): + assert hasattr(v, 'broadcast') + v = v.broadcast() + self.meta[k] = v + new_meta[k] = v + + out_dict['meta'] = new_meta + return MultiModalTensor(**out_dict) + + +MultiModalInputs = Dict[str, List[MultiModalTensor]] diff --git a/lmdeploy/pytorch/multimodal/image_type.py b/lmdeploy/pytorch/multimodal/image_type.py new file mode 100644 index 0000000000..19211a381f --- /dev/null +++ b/lmdeploy/pytorch/multimodal/image_type.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import Any, ClassVar, Dict + +from PIL import Image + +from .data_type import MultiModalData + + +@dataclass +class ImageData(MultiModalData): + data: Image + loc: int + meta: Dict[str, Any] = None + type: ClassVar[str] = 'image' diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index 63df9a5ae9..4705115bf4 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -2,7 +2,7 @@ # attention module is modified from: # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ from .activation import GeluAndMul, SiluAndMul # noqa: F401 -from .attention import Attention # noqa: F401 +from .attention import Attention, FlashAttention # noqa: F401 from .norm import LayerNorm, RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 from .rotary_embedding import RopeType # noqa: F401 diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index 26f1034d36..684c8122f8 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -9,6 +9,14 @@ from .utils import get_distribute_size +def _update_num_heads(num_heads: int, num_kv_heads: int): + """update heads.""" + world_size, rank = get_world_rank() + num_heads = get_distribute_size(num_heads, world_size, rank) + num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) + return num_heads, num_kv_heads + + class Attention(nn.Module): """Attention layer.""" @@ -22,15 +30,19 @@ def __init__( alibi: bool = False, sliding_window: int = None, logit_softcapping: float = None, - replicate_kv: bool = False, + causal: bool = True, **kwargs, ): super().__init__() - num_heads, num_kv_heads = self._update_num_heads( - num_heads, num_kv_heads, replicate_kv) + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_size is None: + v_head_size = head_size + num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads) layer_backend = get_backend() - impl_builder = layer_backend.get_layer_impl_builder(OpType.Attention) + impl_builder = layer_backend.get_layer_impl_builder( + OpType.PagedAttention) self.impl = impl_builder.build( num_heads=num_heads, @@ -41,18 +53,10 @@ def __init__( alibi=alibi, sliding_window=sliding_window, logit_softcapping=logit_softcapping, + causal=causal, **kwargs, ) - def _update_num_heads(self, num_heads: int, num_kv_heads: int, - replicate_kv: bool): - """update heads.""" - world_size, rank = get_world_rank() - num_heads = get_distribute_size(num_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) - return num_heads, num_kv_heads - def forward( self, query: torch.Tensor, @@ -77,3 +81,75 @@ def forward( v_scales_zeros=v_scales_zeros, inplace=inplace, ) + + +class FlashAttention(nn.Module): + """flash attention w/o paging.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logit_softcapping: float = None, + **kwargs, + ): + super().__init__() + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_dim is None: + v_head_dim = head_dim + num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads) + + layer_backend = get_backend() + + impl_builder = layer_backend.get_layer_impl_builder( + OpType.FlashAttention) + + self.impl = impl_builder.build( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logit_softcapping=logit_softcapping, + **kwargs, + ) + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + q_start_loc: torch.Tensor, + q_seqlens: torch.Tensor, + kv_start_loc: torch.Tensor = None, + kv_seqlens: torch.Tensor = None, + max_q_seqlen: int = None) -> torch.Tensor: + """forward.""" + + if max_q_seqlen is None: + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + + if kv_start_loc is None and kv_seqlens is None: + kv_start_loc = q_start_loc + kv_seqlens = q_seqlens + + assert kv_start_loc is not None + assert kv_seqlens is not None + + return self.impl.forward( + query, + key, + value, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_q_seqlen=max_q_seqlen, + ) diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 08040ee00c..486c684a3c 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -12,7 +12,7 @@ from ..backends import OpType, get_backend from ..backends.lora import AdapterInfo -from .utils import div_up, get_distribute_size +from .utils import get_distribute_size logger = get_logger('lmdeploy') @@ -32,9 +32,11 @@ def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): size = weight.size(dim) assert size % align == 0 aligned_size = size // align - align_per_chunk = div_up(aligned_size, chunks) - sections = [align_per_chunk] * (chunks - 1) - sections += [aligned_size - align_per_chunk * (chunks - 1)] + + # try best to evenly split chunks + align_per_chunk = aligned_size // chunks + remain = aligned_size % chunks + sections = [align_per_chunk + int(c < remain) for c in range(chunks)] sections = [sec * align for sec in sections] return weight.split(sections, dim=dim) @@ -42,20 +44,24 @@ def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): class QKVMixin: """qkv mixin.""" - def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, - head_size: int, head_size_v: int): + def _get_qkv_out_features(self, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + head_size_v: int, + num_replicate_kv_heads: int = 1): """get io features.""" - all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, - num_kv_heads * head_size_v) + num_kv_heads_real = num_kv_heads // num_replicate_kv_heads + all_out_features = (num_q_heads * head_size, + num_kv_heads_real * head_size, + num_kv_heads_real * head_size_v) return all_out_features - def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, - replicate_kv: bool): + def _update_num_heads(self, num_q_heads: int, num_kv_heads: int): """update num heads.""" world_size, rank = get_world_rank() num_q_heads = get_distribute_size(num_q_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) + num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) return num_q_heads, num_kv_heads @@ -212,7 +218,7 @@ def __init__( self.out_features = out_features self.w_bit = w_bit self.group_size = group_size - self.elem_per_int = 32 // self.w_bit + self.elem_per_int = 32 // w_bit self.lora_adapters = nn.ModuleDict() self.is_tp = is_tp self.colwise = colwise @@ -363,12 +369,9 @@ def __init__(self, w_bit: int, group_size: int, bias: bool, - replicate: Optional[List[bool]] = None, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None): - if replicate is None: - replicate = tuple(False for _ in all_out_features) self.split_section_s = all_out_features elem_per_int = 32 // w_bit @@ -377,9 +380,8 @@ def __init__(self, ] all_out_features = self._update_all_out_features( - all_out_features, w_bit, group_size, replicate) + all_out_features, w_bit, group_size) self.all_out_features = all_out_features - self.replicate = replicate if out_names is None: out_names = torch.arange(len(self.all_out_features)).tolist() assert len(out_names) == len(self.all_out_features) @@ -414,15 +416,12 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int, return in_features, out_features def _update_all_out_features(self, all_out_features: List[int], w_bit: int, - group_size: int, - replicate: Optional[List[bool]]): + group_size: int): """update all out features.""" world_size, rank = get_world_rank() new_all_out_features = [] align = max(32 // w_bit, group_size) - for out_feat, rep in zip(all_out_features, replicate): - if rep: - new_all_out_features.append(out_feat) + for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank, align) new_all_out_features.append(new_out_feat) @@ -433,14 +432,11 @@ def weight_loader(self, param: torch.nn.Parameter, """weight loader.""" world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] - if loaded_weight.dim() == 1: # bias align = max(self.elem_per_int, self.group_size) param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - weight = _chunk_align(loaded_weight, world_size, 0, - align)[rank] + weight = _chunk_align(loaded_weight, world_size, 0, align)[rank] param_w.copy_(weight) if param._weight_type in ['scales', 'bias']: @@ -456,8 +452,7 @@ def weight_loader(self, param: torch.nn.Parameter, ] param_w = param.data.split(quanted_out_feats, 1)[shard_idx] - if not self.replicate[shard_idx]: - weight = _chunk_align(loaded_weight, world_size, -1, align)[rank] + weight = _chunk_align(loaded_weight, world_size, -1, align)[rank] param_w.copy_(weight) def weight_spliter_wz(self, loaded_weight: torch.Tensor): @@ -480,45 +475,82 @@ def __init__(self, head_size_v: int, w_bit: int, group_size: int, - replicate_kv: bool = False, bias: bool = False, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1): self.qkv_split_section_s = self._get_qkv_out_features( - num_q_heads, num_kv_heads, head_size, head_size_v) + num_q_heads, num_kv_heads, head_size, head_size_v, + num_replicate_kv_heads) elem_per_int = 32 // w_bit self.qkv_split_section_wz = [ size // elem_per_int for size in self.qkv_split_section_s ] num_q_heads, num_kv_heads = self._update_num_heads( - num_q_heads, num_kv_heads, replicate_kv) + num_q_heads, num_kv_heads) all_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v) - replicate = (False, replicate_kv, replicate_kv) out_names = ('q', 'k', 'v') self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size self.head_size_v = head_size_v + self.num_replicate_kv_heads = num_replicate_kv_heads + super().__init__(in_features, all_out_features, w_bit=w_bit, group_size=group_size, bias=bias, - replicate=replicate, device=device, is_tp=is_tp, out_names=out_names) def _update_all_out_features(self, all_out_features: List[int], w_bit: int, - group_size: int, - replicate: Optional[List[bool]]): + group_size: int): """update all out features.""" return all_out_features + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = get_world_rank() + chunk_size, chunk_idx = world_size, rank + shard_idx = self.out_names_map[shard_id] + + if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']: + # update to duplicate k/v for tp_size > num_kv_heads + chunk_size = world_size // self.num_replicate_kv_heads + chunk_idx = rank // self.num_replicate_kv_heads + + if loaded_weight.dim() == 1: + # bias + align = max(self.elem_per_int, self.group_size) + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + weight = _chunk_align(loaded_weight, chunk_size, 0, + align)[chunk_idx] + param_w.copy_(weight) + return + + if param._weight_type in ['scales', 'bias']: + # scales + align = max(self.elem_per_int, self.group_size) + param_w = param.data.split(self.all_out_features, -1)[shard_idx] + else: + # qweight or qzeros + align = max(self.elem_per_int, + self.group_size) // self.elem_per_int + quanted_out_feats = [ + feat // self.elem_per_int for feat in self.all_out_features + ] + param_w = param.data.split(quanted_out_feats, 1)[shard_idx] + + weight = _chunk_align(loaded_weight, chunk_size, -1, align)[chunk_idx] + param_w.copy_(weight) + def weight_spliter_wz(self, loaded_weight: torch.Tensor, layout: str = 'default'): @@ -710,18 +742,13 @@ def __init__(self, in_features: int, all_out_features: List[int], bias: bool, - replicate: Optional[List[bool]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None): - if replicate is None: - replicate = tuple(False for _ in all_out_features) self.split_section = all_out_features - all_out_features = self._update_all_out_features( - all_out_features, replicate) + all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features - self.replicate = replicate if out_names is None: out_names = torch.arange(len(self.all_out_features)).tolist() assert len(out_names) == len(self.all_out_features) @@ -748,14 +775,11 @@ def _get_io_features(self, in_features: int, out_features: int, """get io features.""" return in_features, out_features - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" world_size, rank = get_world_rank() new_all_out_features = [] - for out_feat, rep in zip(all_out_features, replicate): - if rep: - new_all_out_features.append(out_feat) + for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank) new_all_out_features.append(new_out_feat) return new_all_out_features @@ -766,8 +790,7 @@ def weight_loader(self, param: torch.nn.Parameter, world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] param_w.copy_(loaded_weight) def weight_spliter(self, loaded_weight: torch.Tensor): @@ -787,38 +810,57 @@ def __init__(self, num_kv_heads: int, head_size: int, head_size_v: int, - replicate_kv: bool = False, bias: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1): self.qkv_split_section = self._get_qkv_out_features( - num_q_heads, num_kv_heads, head_size, head_size_v) + num_q_heads, num_kv_heads, head_size, head_size_v, + num_replicate_kv_heads) num_q_heads, num_kv_heads = self._update_num_heads( - num_q_heads, num_kv_heads, replicate_kv) + num_q_heads, num_kv_heads) all_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v) - replicate = (False, replicate_kv, replicate_kv) out_names = ('q', 'k', 'v') self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size self.head_size_v = head_size_v + self.num_replicate_kv_heads = num_replicate_kv_heads super().__init__(in_features, all_out_features, bias=bias, - replicate=replicate, dtype=dtype, device=device, is_tp=is_tp, out_names=out_names) - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" return all_out_features + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + _, rank = get_world_rank() + shard_idx = self.out_names_map[shard_id] + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + num_head = self.num_q_heads if shard_id == 'q' \ + else self.num_kv_heads + head_dim = self.head_size if shard_id in ['q', 'k'] \ + else self.head_size_v + # update to duplicate k/v for tp_size > num_kv_heads + rank_idx = rank if shard_id == 'q' \ + else rank // self.num_replicate_kv_heads + sec_start = rank_idx * num_head * head_dim + sec_len = num_head * head_dim + loaded_weight = loaded_weight.narrow(dim=0, + start=sec_start, + length=sec_len) + param_w.copy_(loaded_weight) + def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'): @@ -986,18 +1028,13 @@ def __init__(self, in_features: int, all_out_features: List[int], bias: bool, - replicate: Optional[List[bool]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None): - if replicate is None: - replicate = tuple(False for _ in all_out_features) self.split_section = all_out_features - all_out_features = self._update_all_out_features( - all_out_features, replicate) + all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features - self.replicate = replicate if out_names is None: out_names = torch.arange(len(self.all_out_features)).tolist() assert len(out_names) == len(self.all_out_features) @@ -1022,14 +1059,11 @@ def _get_io_features(self, in_features: int, out_features: int, """get io features.""" return in_features, out_features - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" world_size, rank = get_world_rank() new_all_out_features = [] - for out_feat, rep in zip(all_out_features, replicate): - if rep: - new_all_out_features.append(out_feat) + for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank) new_all_out_features.append(new_out_feat) return new_all_out_features @@ -1040,8 +1074,7 @@ def weight_loader(self, param: torch.nn.Parameter, world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] param_w.copy_(loaded_weight) def weight_spliter(self, loaded_weight: torch.Tensor): @@ -1061,35 +1094,36 @@ def __init__(self, num_kv_heads: int, head_size: int, head_size_v: int, - replicate_kv: bool = False, bias: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1): + self.qkv_split_section = self._get_qkv_out_features( - num_q_heads, num_kv_heads, head_size, head_size_v) + num_q_heads, num_kv_heads, head_size, head_size_v, + num_replicate_kv_heads) num_q_heads, num_kv_heads = self._update_num_heads( - num_q_heads, num_kv_heads, replicate_kv) + num_q_heads, num_kv_heads) all_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v) - replicate = (False, replicate_kv, replicate_kv) out_names = ('q', 'k', 'v') self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size self.head_size_v = head_size_v + self.num_replicate_kv_heads = num_replicate_kv_heads + super().__init__(in_features, all_out_features, bias=bias, - replicate=replicate, dtype=dtype, device=device, is_tp=is_tp, out_names=out_names) - def _update_all_out_features(self, all_out_features: List[int], - replicate: Optional[List[bool]]): + def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" return all_out_features @@ -1097,15 +1131,20 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """weight loader.""" world_size, rank = get_world_rank() + chunk_size, chunk_idx = world_size, rank shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] - if not self.replicate[shard_idx]: - if shard_idx in [0, 1]: - loaded_weight = _chunk_align(loaded_weight, world_size, 0, - self.head_size)[rank] - if shard_idx == 2: - loaded_weight = _chunk_align(loaded_weight, world_size, 0, - self.head_size_v)[rank] + + if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']: + # update to duplicate k/v for tp_size > num_kv_heads + chunk_size = world_size // self.num_replicate_kv_heads + chunk_idx = rank // self.num_replicate_kv_heads + if shard_idx in [0, 1]: + loaded_weight = _chunk_align(loaded_weight, chunk_size, 0, + self.head_size)[chunk_idx] + elif shard_idx == 2: + loaded_weight = _chunk_align(loaded_weight, chunk_size, 0, + self.head_size_v)[chunk_idx] param_w.copy_(loaded_weight) def weight_spliter(self, @@ -1291,12 +1330,12 @@ def build_qkv_proj(in_features: int, num_kv_heads: int, head_size: int, head_size_v: int = None, - replicate_kv: bool = False, bias: bool = False, quant_config: Any = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - is_tp: bool = True): + is_tp: bool = True, + num_replicate_kv_heads: int = 1): """build qkv proj.""" if is_tp: world_size, _ = get_world_rank() @@ -1306,48 +1345,42 @@ def build_qkv_proj(in_features: int, head_size_v = head_size if quant_config is None: - return QKVBaseLinear( - in_features=in_features, - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - head_size_v=head_size_v, - replicate_kv=replicate_kv, - bias=bias, - dtype=dtype, - device=device, - is_tp=is_tp, - ) + return QKVBaseLinear(in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + num_replicate_kv_heads=num_replicate_kv_heads) quant_method = quant_config['quant_method'] if quant_method == 'awq': w_bit = quant_config.get('bits', 4) group_size = quant_config.get('group_size', 128) - return QKVAwqLinear( - in_features=in_features, - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - head_size_v=head_size_v, - replicate_kv=replicate_kv, - w_bit=w_bit, - group_size=group_size, - bias=bias, - device=device, - is_tp=is_tp, - ) + return QKVAwqLinear(in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + w_bit=w_bit, + group_size=group_size, + bias=bias, + device=device, + is_tp=is_tp, + num_replicate_kv_heads=num_replicate_kv_heads) if quant_method == 'smooth_quant': - return QKVW8A8Linear( - in_features=in_features, - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - head_size_v=head_size_v, - replicate_kv=replicate_kv, - bias=bias, - dtype=dtype, - device=device, - is_tp=is_tp, - ) + return QKVW8A8Linear(in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + num_replicate_kv_heads=num_replicate_kv_heads) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py index 3289f858a7..3b60ca21de 100644 --- a/lmdeploy/pytorch/nn/utils.py +++ b/lmdeploy/pytorch/nn/utils.py @@ -11,7 +11,10 @@ def get_distribute_size(feature_size: int, """update feature size.""" assert feature_size % align == 0 aligned_size = feature_size // align - align_per_rank = div_up(aligned_size, world_size) - prev_feats = align_per_rank * rank - updated_aligned_size = min(align_per_rank, aligned_size - prev_feats) + # try to make every rank has same amount of feats + updated_aligned_size = aligned_size // world_size + # if there are still some remain, given them to + # each rank + if rank < aligned_size % world_size: + updated_aligned_size += 1 return updated_aligned_size * align diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py index 7fa568651b..67452f78e3 100644 --- a/lmdeploy/pytorch/supported_models.py +++ b/lmdeploy/pytorch/supported_models.py @@ -47,9 +47,9 @@ # cogvlm-chat CogVLMForCausalLM=True, # llava - LlavaLlamaForCausalLM=True, + LlavaLlamaForCausalLM=False, # llava mistral - LlavaMistralForCausalLM=True, + LlavaMistralForCausalLM=False, # deepseekvl MultiModalityCausalLM=False, # StarCoder2 diff --git a/lmdeploy/pytorch/tools/make_inputs.py b/lmdeploy/pytorch/tools/make_inputs.py index f2d23830b7..053e7d0918 100644 --- a/lmdeploy/pytorch/tools/make_inputs.py +++ b/lmdeploy/pytorch/tools/make_inputs.py @@ -135,6 +135,7 @@ def __fill_kv_caches(kv_caches, past_key_values, block_offsets): return StepContext.new( inputs=model_inputs, + model_config=model_config, world_size=world_size, kv_caches=kv_caches, ) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index e9083fc110..0e1a01f4f4 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -325,9 +325,10 @@ def __call__(self, """Inference a batch of prompts. Args: - prompts (List[str] | str | List[Dict] | List[Dict]): a batch of - prompts. It accepts: string prompt, a list of string prompts, - a chat history in OpenAI format or a list of chat history. + prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a + batch of prompts. It accepts: string prompt, a list of string + prompts, a chat history in OpenAI format or a list of chat + history. gen_config (GenerationConfig | None): a instance of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -471,9 +472,10 @@ def batch_infer(self, """Inference a batch of prompts. Args: - prompts (List[str] | str | List[Dict] | List[Dict]): a batch of - prompts. It accepts: string prompt, a list of string prompts, - a chat history in OpenAI format or a list of chat history. + prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a + batch of prompts. It accepts: string prompt, a list of string + prompts, a chat history in OpenAI format or a list of chat + history. gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -516,9 +518,10 @@ def stream_infer( """Inference a batch of prompts with stream mode. Args: - prompts (List[str] | str | List[Dict] | List[Dict]): a batch of - prompts. It accepts: string prompt, a list of string prompts, - a chat history in OpenAI format or a list of chat history. + prompts (List[str] | str | List[Dict] | List[List[Dict]]]):a + batch of prompts. It accepts: string prompt, a list of string + prompts, a chat history in OpenAI format or a list of chat + history. gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -622,8 +625,8 @@ async def generate( if gen_config.stop_token_ids is None: gen_config.stop_token_ids = self.stop_words if not gen_config.do_sample: - logger.warn(f'GenerationConfig: {gen_config}') - logger.warn( + logger.warning(f'GenerationConfig: {gen_config}') + logger.warning( 'Since v0.6.0, lmdeploy add `do_sample` in ' 'GenerationConfig. It defaults to False, meaning greedy ' 'decoding. Please set `do_sample=True` if sampling ' diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 2d0560720d..cce9567896 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -946,6 +946,20 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return JSONResponse(ret) +def handle_torchrun(): + """To disable mmengine logging logic when using torchrun.""" + + def dummy_get_device_id(): + return 0 + + if int(os.environ.get('LOCAL_RANK', -1)) > 0: + from lmdeploy.vl.model.utils import _set_func + + # the replacement can't be recovered + _set_func('mmengine.logging.logger._get_device_id', + dummy_get_device_id) + + @router.on_event('startup') async def startup_event(): if VariableInterface.proxy_url is None: @@ -1069,8 +1083,8 @@ def serve(model_path: str, ssl_certfile = os.environ['SSL_CERTFILE'] http_or_https = 'https' + handle_torchrun() _, pipeline_class = get_task(model_path) - VariableInterface.async_engine = pipeline_class( model_path=model_path, model_name=model_name, diff --git a/lmdeploy/serve/proxy/constants.py b/lmdeploy/serve/proxy/constants.py index 88d86a3e33..5bf6e67659 100644 --- a/lmdeploy/serve/proxy/constants.py +++ b/lmdeploy/serve/proxy/constants.py @@ -2,8 +2,8 @@ import enum -LATENCY_DEEQUE_LEN = 15 -API_TIMEOUT_LEN = 100 +LATENCY_DEQUE_LEN = 15 +API_READ_TIMEOUT = 100 class Strategy(enum.Enum): diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 5f05930bd0..392ede3267 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import asyncio import copy import json import os @@ -18,14 +19,15 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field +from requests.exceptions import RequestException from lmdeploy.serve.openai.api_server import (check_api_key, create_error_response) from lmdeploy.serve.openai.protocol import ( # noqa: E501 ChatCompletionRequest, CompletionRequest, ModelCard, ModelList, ModelPermission) -from lmdeploy.serve.proxy.constants import (API_TIMEOUT_LEN, - LATENCY_DEEQUE_LEN, ErrorCodes, +from lmdeploy.serve.proxy.constants import (API_READ_TIMEOUT, + LATENCY_DEQUE_LEN, ErrorCodes, Strategy, err_msg) from lmdeploy.utils import get_logger @@ -36,7 +38,7 @@ class Status(BaseModel): """Status protocol consists of models' information.""" models: Optional[List[str]] = Field(default=[], examples=[[]]) unfinished: int = 0 - latency: Deque = Field(default=deque(maxlen=LATENCY_DEEQUE_LEN), + latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]]) speed: Optional[int] = Field(default=None, examples=[None]) @@ -87,6 +89,9 @@ def __init__(self, with open(self.config_path, 'r') as config_file: self.nodes = yaml.safe_load(config_file)['nodes'] for url, status in self.nodes.items(): + latency = deque(status.get('latency', []), + maxlen=LATENCY_DEQUE_LEN) + status['latency'] = latency status = Status(**status) self.nodes[url] = status self.heart_beat_thread = threading.Thread(target=heart_beat_controller, @@ -99,7 +104,7 @@ def update_config_file(self): nodes = copy.deepcopy(self.nodes) for url, status in nodes.items(): nodes[url] = status.model_dump() - nodes[url]['latency'] = list(status.latency) + nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:] with open(self.config_path, 'w') as config_file: # update cfg yml yaml.dump(dict(nodes=nodes), config_file) @@ -149,7 +154,8 @@ def remove_stale_nodes_by_expiration(self): to_be_deleted.append(node_url) for node_url in to_be_deleted: self.remove(node_url) - logger.info(f'Removed node_url: {node_url}') + logger.info(f'Removed node_url: {node_url} ' + 'due to heart beat expiration') @property def model_list(self): @@ -251,7 +257,7 @@ def handle_unavailable_model(self, model_name): Args: model_name (str): the model in the request. """ - logger.info(f'no model name: {model_name}') + logger.warning(f'no model name: {model_name}') ret = { 'error_code': ErrorCodes.MODEL_NOT_FOUND, 'text': err_msg[ErrorCodes.MODEL_NOT_FOUND], @@ -260,51 +266,54 @@ def handle_unavailable_model(self, model_name): def handle_api_timeout(self, node_url): """Handle the api time out.""" - logger.info(f'api timeout: {node_url}') + logger.warning(f'api timeout: {node_url}') ret = { - 'error_code': ErrorCodes.API_TIMEOUT, + 'error_code': ErrorCodes.API_TIMEOUT.value, 'text': err_msg[ErrorCodes.API_TIMEOUT], } return json.dumps(ret).encode() + b'\n' - def stream_generate(self, request: Dict, node_url: str, node_path: str): + def stream_generate(self, request: Dict, node_url: str, endpoint: str): """Return a generator to handle the input request. Args: request (Dict): the input request. node_url (str): the node url. - node_path (str): the node path. Such as `/v1/chat/completions`. + endpoint (str): the endpoint. Such as `/v1/chat/completions`. """ try: response = requests.post( - node_url + node_path, + node_url + endpoint, json=request, - stream=request['stream'], - timeout=API_TIMEOUT_LEN, + stream=True, + timeout=(5, API_READ_TIMEOUT), ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\n'): if chunk: yield chunk + b'\n\n' - except requests.exceptions.RequestException as e: # noqa + except (Exception, GeneratorExit, RequestException) as e: # noqa + logger.error(f'catched an exception: {e}') + # exception happened, reduce unfinished num yield self.handle_api_timeout(node_url) - async def generate(self, request: Dict, node_url: str, node_path: str): + async def generate(self, request: Dict, node_url: str, endpoint: str): """Return a the response of the input request. Args: request (Dict): the input request. node_url (str): the node url. - node_path (str): the node path. Such as `/v1/chat/completions`. + endpoint (str): the endpoint. Such as `/v1/chat/completions`. """ try: import httpx async with httpx.AsyncClient() as client: - response = await client.post(node_url + node_path, + response = await client.post(node_url + endpoint, json=request, - timeout=API_TIMEOUT_LEN) + timeout=API_READ_TIMEOUT) return response.text - except requests.exceptions.RequestException as e: # noqa + except (Exception, GeneratorExit, RequestException, asyncio.CancelledError) as e: # noqa # yapf: disable + logger.error(f'catched an exception: {e}') return self.handle_api_timeout(node_url) def pre_call(self, node_url): @@ -381,7 +390,11 @@ def add_node(node: Node, raw_request: Request = None): RPM or other metric. All the values of nodes should be the same metric. """ try: - node_manager.add(node.url, node.status) + res = node_manager.add(node.url, node.status) + if res is not None: + logger.error(f'add node {node.url} failed, {res}') + return res + logger.info(f'add node {node.url} successfully') return 'Added successfully' except: # noqa return 'Failed to add, please check the input url.' @@ -392,8 +405,10 @@ def remove_node(node_url: str): """Show available models.""" try: node_manager.remove(node_url) + logger.info(f'delete node {node_url} successfully') return 'Deleted successfully' except: # noqa + logger.error(f'delete node {node_url} failed.') return 'Failed to delete, please check the input url.' @@ -407,28 +422,50 @@ async def chat_completions_v1(request: ChatCompletionRequest, The request should be a JSON object with the following fields: - model: model name. Available from /v1/models. - - messages: string prompt or chat history in OpenAI format. A example - for chat history is `[{"role": "user", "content":"knock knock"}]`. + - messages: string prompt or chat history in OpenAI format. Chat history + example: `[{"role": "user", "content": "hi"}]`. - temperature (float): to modulate the next token probability - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - - max_tokens (int): output token nums + - max_tokens (int | None): output token nums. Default to None. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty - stop (str | List[str] | None): To stop generating further tokens. Only accept stop words that's encoded to one token idex. + - response_format (Dict | None): Only pytorch backend support formatting + response. Examples: `{"type": "json_schema", "json_schema": {"name": + "test","schema": {"properties": {"name": {"type": "string"}}, + "required": ["name"], "type": "object"}}}` + or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}` + - logit_bias (Dict): Bias to logits. Only supported in pytorch engine. + - tools (List): A list of tools the model may call. Currently, only + internlm2 functions are supported as a tool. Use this to specify a + list of functions for which the model can generate JSON inputs. + - tool_choice (str | object): Controls which (if any) tool is called by + the model. `none` means the model will not call any tool and instead + generates a message. Specifying a particular tool via {"type": + "function", "function": {"name": "my_function"}} forces the model to + call that tool. `auto` or `required` will put all the tools information + to the model. Additional arguments supported by LMDeploy: + - top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering - ignore_eos (bool): indicator for ignoring eos - - session_id (int): if not specified, will set random value + - skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be True. + - min_new_tokens (int): To generate at least numbers of tokens. + - min_p (float): Minimum token probability, which will be scaled by the + probability of the most likely token. It must be a value between + 0 and 1. Typical values are in the 0.01-0.2 range, comparably + selective as setting `top_p` in the 0.99-0.8 range (use the + opposite of normal `top_p` values) Currently we do not support the following features: - - function_call (Users should implement this by themselves) - - logit_bias (not supported yet) - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ @@ -439,6 +476,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, if not node_url: return node_manager.handle_unavailable_model(request.model) + logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() start = node_manager.pre_call(node_url) if request.stream is True: @@ -465,13 +503,13 @@ async def completions_v1(request: CompletionRequest, - model (str): model name. Available from /v1/models. - prompt (str): the input prompt. - suffix (str): The suffix that comes after a completion of inserted text. - - max_tokens (int): output token nums + - max_tokens (int): output token nums. Default to 16. - temperature (float): to modulate the next token probability - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty @@ -481,7 +519,8 @@ async def completions_v1(request: CompletionRequest, Additional arguments supported by LMDeploy: - ignore_eos (bool): indicator for ignoring eos - - session_id (int): if not specified, will set random value + - skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be True. - top_k (int): The number of the highest probability vocabulary tokens to keep for top-k-filtering @@ -497,6 +536,7 @@ async def completions_v1(request: CompletionRequest, if not node_url: return node_manager.handle_unavailable_model(request.model) + logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() start = node_manager.pre_call(node_url) if request.stream is True: @@ -517,6 +557,7 @@ def proxy(server_name: str = '0.0.0.0', 'min_observed_latency'] = 'min_expected_latency', api_keys: Optional[Union[List[str], str]] = None, ssl: bool = False, + log_level: str = 'INFO', **kwargs): """To launch the proxy server. @@ -540,6 +581,7 @@ def proxy(server_name: str = '0.0.0.0', if ssl: ssl_keyfile = os.environ['SSL_KEYFILE'] ssl_certfile = os.environ['SSL_CERTFILE'] + logger.setLevel(log_level) uvicorn.run(app=app, host=server_name, port=server_port, diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index c293cd71c8..becf1b76fb 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -1,148 +1,208 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Union +import asyncio +from typing import Dict, List, Literal, Optional, Tuple, Union -import numpy as np +import PIL +from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig, + VisionConfig) from lmdeploy.pytorch.check_env import try_import_deeplink from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX, IMAGE_TOKEN from lmdeploy.vl.engine import ImageEncoder -from lmdeploy.vl.templates import VLPromptType, get_vl_prompt_template +from lmdeploy.vl.utils import load_image logger = get_logger('lmdeploy') +VLPromptType = Union[str, Tuple[str, PIL.Image.Image], + Tuple[str, List[PIL.Image.Image]]] + class VLAsyncEngine(AsyncEngine): """Visual Language Async inference engine.""" - def __init__(self, model_path: str, **kwargs) -> None: - vision_config = kwargs.pop('vision_config', None) - backend_config = kwargs.get('backend_config', None) - if kwargs.get('backend', '') == 'pytorch': + def __init__(self, + model_path: str, + backend: Literal['turbomind', 'pytorch'] = 'turbomind', + backend_config: Optional[Union[TurbomindEngineConfig, + PytorchEngineConfig]] = None, + vision_config: Optional[VisionConfig] = None, + **kwargs) -> None: + if backend == 'pytorch': try_import_deeplink(backend_config.device_type) self.vl_encoder = ImageEncoder(model_path, + backend, vision_config, backend_config=backend_config) - super().__init__(model_path, **kwargs) + super().__init__(model_path, + backend=backend, + backend_config=backend_config, + **kwargs) if self.model_name == 'base': raise RuntimeError( 'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501 ) - self.vl_prompt_template = get_vl_prompt_template( - model_path, self.chat_template, self.model_name) - def _convert_prompts(self, + @classmethod + def _convert_prompts(cls, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]]): - """convert prompts to openai format.""" + """convert prompts to openai GPT4V format.""" if isinstance(prompts, str) or isinstance(prompts, tuple): - _prompts = self.vl_prompt_template.prompt_to_messages(prompts) + _prompts = cls.prompt_to_messages(prompts) elif isinstance(prompts[0], tuple) or isinstance(prompts[0], str): - _prompts = [ - self.vl_prompt_template.prompt_to_messages(x) for x in prompts - ] + _prompts = [cls.prompt_to_messages(x) for x in prompts] else: _prompts = prompts return _prompts async def _get_prompt_input(self, - prompt: Dict, + messages: Union[str, List[Dict]], do_preprocess: bool, sequence_start: bool, adapter_name: str, tools: Optional[List[object]] = None, **kwargs): - """get input_ids, embeddings and offsets.""" - if do_preprocess: - decorated = self.vl_prompt_template.messages2prompt( - prompt, sequence_start) - else: - decorated = prompt - segs = decorated.split(IMAGE_TOKEN) - - results = {} - input_ids = [] - from lmdeploy.vl.templates import (MllamaTempateWrapper, - MolmoChatTemplateWrapper, - Qwen2VLChatTemplateWrapper) - ranges = None - grid_thws = None - if len(segs) > 1: - # yapf: disable - images_with_kwargs = await self.vl_prompt_template.async_collect_pil_images(prompt) # noqa: E501 - # yapf: enable - features = [] - if len(images_with_kwargs) > 0: - images, image_kwargs = list(zip(*images_with_kwargs)) - features = await self.vl_encoder.async_infer( - images, image_kwargs) - - from lmdeploy.vl.templates import MiniCPMVTempateWrapper - if isinstance(self.vl_prompt_template, MiniCPMVTempateWrapper): - decorated, features = self.vl_prompt_template.update_image_token( # noqa: E501 - decorated, features) - segs = decorated.split(IMAGE_TOKEN) - - if isinstance(self.vl_prompt_template, - Qwen2VLChatTemplateWrapper): - grid_thws = [x['grid_thw'] for x in features] - features = [x['embeddings'] for x in features] - - if isinstance(self.vl_prompt_template, MllamaTempateWrapper): - # llama3.2 just encode <|image|> and inference - decorated = decorated.replace(IMAGE_TOKEN, '<|image|>') - input_ids = self.tokenizer.encode(decorated, - add_bos=sequence_start) - results['input_ids'] = input_ids - results['prompt'] = decorated - assert len(features) - results['cross_attention_states'] = features[0] - return results - - if isinstance(self.vl_prompt_template, - MolmoChatTemplateWrapper): - return features[0] - - features = [x.cpu().numpy() for x in features] - input_ids = [] - begins = [] - ends = [] - if len(segs) != len(features) + 1: - logger.error( - f'the number of {IMAGE_TOKEN} is not equal ' - f'to input images, {len(segs) - 1} vs {len(features)}') - features = features[:len(segs) - 1] - for i, seg in enumerate(segs): - if i > 0 and i <= len(features): - image_dim = features[i - 1].shape[0] - begins.append(len(input_ids)) - ends.append(begins[-1] + image_dim) - input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) - seg_ids = self.tokenizer.encode(seg, - add_bos=((i == 0) - and sequence_start)) - input_ids.extend(seg_ids) - ranges = np.stack([begins, ends], axis=1).tolist() - results['input_embeddings'] = features or None - results['input_embedding_ranges'] = ranges or None + """process messages and return the required data for the inference + engines. + + Refer to pytorch.engine.EngineInstance.async_stream_infer and + turbomind.TurboMindInstance.async_stream_infer for the argument + specification. + """ + if isinstance(messages, str): + return await super()._get_prompt_input(messages, do_preprocess, + sequence_start, + adapter_name, tools, + **kwargs) + elif isinstance(messages, List): + has_multimodal_input = any( + isinstance(message['content'], list) and any( + item['type'] in ['image_url', 'image_data'] + for item in message['content']) for message in messages) + if not has_multimodal_input: + return await super()._get_prompt_input(messages, do_preprocess, + sequence_start, + adapter_name, tools, + **kwargs) else: - input_ids = self.tokenizer.encode(decorated, - add_bos=sequence_start) - - if isinstance(self.vl_prompt_template, Qwen2VLChatTemplateWrapper): - # TODO: refactor _get_prompt_input function - mrope_position_ids, mrope_position_delta = \ - self.vl_prompt_template.get_mrope_info( - len(input_ids), grid_thws=grid_thws, - embedding_ranges=ranges) - results['mrope_position_ids'] = mrope_position_ids - results['mrope_position_delta'] = mrope_position_delta - - results['input_ids'] = input_ids - results['prompt'] = decorated + raise RuntimeError(f'unsupported messages {messages}') + + messages = await self.async_convert_to_pil_images(messages) + results = await self.vl_encoder.preprocess(messages) + if self.backend == 'turbomind': + # for tm engine, this module perform vision embedding after image + # preprocessing. It utilizes the hf model's vision embeddings + # functions and returns the input_ids, input_embeddings, + # embedding_ranges and so on. All the returned values are passed + # to tm engine for token generation + results = await self.vl_encoder.async_infer(results) + results = await self.vl_encoder.wrap_for_turbomind( + results, self.chat_template, self.tokenizer, sequence_start) + elif self.backend == 'pytorch': + # for pt engine, this module only conduct the image preprocessing + # It leaves the vision embedding to the pt engine + results = await self.vl_encoder.wrap_for_pytorch( + results, self.chat_template, self.tokenizer, sequence_start) return results + @classmethod + async def async_convert_to_pil_images(cls, + messages: List[Dict]) -> List[Dict]: + """Scan the provided messages to find image URLs or base64-encoded + image data. Loads the images into Pillow image objects. + + Args: + messages (List[Dict]): a user request of GPT4V message format + """ + if isinstance(messages, Dict): + messages = [messages] + assert isinstance(messages, List) + + out_messages = [None] * len(messages) + + def _inner_call(i, in_messages, out_messages): + role = in_messages[i]['role'] + content = in_messages[i]['content'] + assert role in ['system', 'user', 'assistant'], \ + f'unsupported role "{role}"' + if role != 'user' or isinstance(content, str): + # the content is a user's prompt or an assistant's prompt, + # returning it directly + out_messages[i] = in_messages[i] + return + # the role is a user and the content is a list, in which there + # might be image_url or image_data + assert isinstance(content, List) + message = dict(role=role, content=[]) + for item in content: + # image url or base64-encoded image data + if item['type'] == 'image_url': + """ + convert the following item: + { + 'type': 'image_url', + 'image_url': { + 'url': 'image url or base64-encoded image data', + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_url'].copy() + try: + url = data.pop('url') + image = load_image(url) + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'image_data': + """ + convert the following item: + { + 'type': 'image_data', + 'image_data': { + 'data': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_data'].copy() + try: + image = data.pop('data') + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'text': + message['content'].append(item) + else: + logger.error(f'unexpected content type {message}') + out_messages[i] = message + + await asyncio.gather(*[ + asyncio.get_event_loop().run_in_executor(None, _inner_call, i, + messages, out_messages) + for i in range(len(messages)) + ]) + return out_messages + def batch_infer(self, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]], **kwargs): @@ -173,3 +233,46 @@ def chat(self, prompts: VLPromptType, **kwargs): last_round = sess.history[-1] sess.history[-1] = (prompts, last_round[-1]) return sess + + @classmethod + def prompt_to_messages(cls, prompt: VLPromptType): + """convert prompt to GTP4V format.""" + messages = { + 'role': 'user', + 'content': [{ + 'type': 'text', + 'text': '', + }] + } + if isinstance(prompt, str): + messages['content'][0]['text'] = prompt + else: + prompt, images = prompt + if not isinstance(images, list): + images = [images] + messages['content'][0]['text'] = prompt + for image in images: + # 'image_url': means url or local path to image. + # 'image_data': means PIL.Image.Image object. + if isinstance(image, str): + image = load_image(image) + item = { + 'type': 'image_data', + 'image_data': { + 'data': image + } + } + elif isinstance(image, PIL.Image.Image): + item = { + 'type': 'image_data', + 'image_data': { + 'data': image + } + } + else: + raise ValueError( + 'image should be a str(url/path) or PIL.Image.Image') + + messages['content'].append(item) + + return [messages] diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 77f0bc8dc8..176c3191f4 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -6,7 +6,7 @@ import fire import torch -from lmdeploy.archs import get_model_arch +from lmdeploy.archs import get_model_arch, search_nested_config from lmdeploy.messages import TurbomindEngineConfig from lmdeploy.model import MODELS, best_match_model from lmdeploy.utils import get_logger, get_model @@ -129,16 +129,17 @@ def get_output_model_registered_name_and_config(model_path: str, ] else 'float16' elif dtype in ['float16', 'bfloat16']: if weight_type == 'int4': - logger.warn(f'The model {model_path} is a quantized model, so the ' - f'specified data type {dtype} is ignored') + logger.warning( + f'The model {model_path} is a quantized model, so the ' + f'specified data type {dtype} is ignored') else: weight_type = dtype else: assert 0, f'unsupported specified data type {dtype}' if weight_type == 'bfloat16' and not is_bf16_supported(): - logger.warn('data type fallback to float16 since ' - 'torch.cuda.is_bf16_supported is False') + logger.warning('data type fallback to float16 since ' + 'torch.cuda.is_bf16_supported is False') weight_type = 'float16' config.model_config.model_arch = model_arch config.model_config.weight_type = weight_type @@ -174,23 +175,6 @@ def pack_model_repository(workspace_path: str): dst=osp.join(model_repo_dir, 'postprocessing')) -def find_quantization_config(nested, target_key): - if isinstance(nested, dict): - for key, value in nested.items(): - if key == target_key: - return value - if isinstance(value, (dict, list)): - result = find_quantization_config(value, target_key) - if result is not None: - return result - elif isinstance(nested, list): - for item in nested: - result = find_quantization_config(item, target_key) - if result is not None: - return result - return None - - def get_tm_model(model_path, model_name, chat_template_name, @@ -213,8 +197,7 @@ def get_tm_model(model_path, If it is None, the turbomind model won't be saved """ _, cfg = get_model_arch(model_path) - quant_config = find_quantization_config(cfg.to_dict(), - 'quantization_config') + quant_config = search_nested_config(cfg.to_dict(), 'quantization_config') if quant_config: quant_method = quant_config.get('quant_method') _group_size = int(quant_config.get('group_size', 0)) diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py index 52497175ef..1754161ff5 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -191,7 +191,7 @@ def __init__(self, model: BaseOutputModel): self.attn_bias = model.model_config.attn_bias def _reorder_and_merge(self, qkvo): - q, k, v, o = map(transpose, qkvo) + q, k, v, o = qkvo # reorder output dim for tm's rotary embedding layout if self.model.permute_qk: q = permute_v2(q, self.head_dim) @@ -202,6 +202,27 @@ def _reorder_and_merge(self, qkvo): o = torch.zeros_like(q) return qkv, o + def _repeat_kv(self, qkvo, kind: str): + """replicate kv.""" + q, k, v, o = qkvo + head_dim = self.model.model_config.size_per_head + hidden_dim = self.model.model_config.hidden_units + + def _repeat(x): + dim = hidden_dim if kind != 'bias' else 1 + x = x.reshape(dim, -1, head_dim) + x = x.repeat(1, 1, self.model.repeat_kv) + x = x.reshape(dim, -1) + return x + + k, v = map(_repeat, (k, v)) + if kind == 'bias': + if o is None: + o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device) + q, k, v, o = map(torch.squeeze, (q, k, v, o)) + + return (q, k, v, o) + def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs): if all(x is None for x in qkvo): return @@ -209,6 +230,9 @@ def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs): if is_lora_a: qkv, o = map(transpose, qkvo) else: + qkvo = tuple(map(transpose, qkvo)) + if self.model.repeat_kv: + qkvo = self._repeat_kv(qkvo, kind) qkv, o = self._reorder_and_merge(qkvo) self.model.save_split(pack_fn(qkv), self._attn.format(idx, 'w_qkv', kind), diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index f2c981bb24..7ea1a84f35 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -78,6 +78,17 @@ def __init__(self, self.model_config.expert_inter_size = _pad_inter_size( self.model_config.expert_inter_size, self.model_config.group_size, self.tensor_para_size) + + # head_num is divisble by tp but kv_head_num is not + # and tp is divisble by kv_head_num + assert self.model_config.head_num % self.tensor_para_size == 0 + self.repeat_kv = 0 + if (self.tensor_para_size > self.model_config.kv_head_num and + self.tensor_para_size % self.model_config.kv_head_num == 0): + self.repeat_kv = (self.tensor_para_size // + self.model_config.kv_head_num) + self.model_config.kv_head_num = self.tensor_para_size + self.model_config.verify() assert self.model_config.kv_head_num % self.tensor_para_size == 0 diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index 11e99edfa0..2b9c5156ed 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.archs import get_model_arch +from lmdeploy.archs import get_model_arch, search_nested_config from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -80,7 +80,12 @@ def _is_head_dim_supported(cfg): if os.path.exists(triton_model_path): support_by_turbomind = True else: + arch, cfg = get_model_arch(model_path) + quant_method = search_nested_config(cfg.to_dict(), 'quant_method') + if quant_method and quant_method in ['smooth_quant']: + # tm hasn't support quantized models by applying smoothquant + return False if arch in SUPPORTED_ARCHS.keys(): support_by_turbomind = True diff --git a/lmdeploy/version.py b/lmdeploy/version.py index d9f4307a78..f705fcb332 100644 --- a/lmdeploy/version.py +++ b/lmdeploy/version.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple -__version__ = '0.6.3' +__version__ = '0.6.4' short_version = __version__ diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py index 124fd537c6..7f786d5f90 100644 --- a/lmdeploy/vl/engine.py +++ b/lmdeploy/vl/engine.py @@ -1,13 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio -import inspect -import queue -import time -from threading import Thread from typing import Dict, List, Optional, Union import torch -from PIL.Image import Image from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig, VisionConfig) @@ -27,169 +22,94 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None: raise e -class Record: - """Batching manager.""" - - def __init__(self, thread_safe): - self.thread_safe = thread_safe - self.number = [] - self.waiting = [] - self.kwargs = [] - self.done = [] - self.res_que = [] - self.total = 0 - - def enqueue(self, images: List[Image], kwargs: List[Dict], - que: Union[queue.Queue, asyncio.Queue]): - """add ith request to manager.""" - self.number.append(len(images)) - self.waiting.extend(images) - self.kwargs.extend(kwargs) - self.res_que.append(que) - self.total += len(images) - self.log('received', len(images)) - - def dequeue(self, max_batch_size): - """try to dequeue max batch size images.""" - inputs = self.waiting[:max_batch_size] - kwargs = self.kwargs[:max_batch_size] - self.waiting = self.waiting[max_batch_size:] - self.kwargs = self.kwargs[max_batch_size:] - self.total -= len(inputs) - self.log('process', len(inputs)) - return inputs, kwargs - - def notify(self): - """set result if request i is finished.""" - if len(self.number) == 0 or self.number[0] > len(self.done): - return False - num_images = self.number.pop(0) - outputs = self.done[:num_images] - self.done = self.done[num_images:] - que = self.res_que.pop(0) - self.log('done', num_images) - if self.thread_safe: - que._loop.call_soon_threadsafe(que.put_nowait, outputs) - else: - que.put_nowait(outputs) - return True - - def log(self, task: str, num: int): - logger.info(f'ImageEncoder {task} {num} images, ' - f'left {self.total} images.') - - class ImageEncoder: """Image encoder.""" - def __init__(self, - model_path: str, - vision_config: VisionConfig = None, - backend_config: Optional[Union[TurbomindEngineConfig, - PytorchEngineConfig]] = None): - self.model = load_vl_model(model_path, backend_config=backend_config) + def __init__( + self, + model_path: str, + backend: str, + vision_config: VisionConfig = None, + backend_config: Optional[Union[TurbomindEngineConfig, + PytorchEngineConfig]] = None, + ): + self.model = load_vl_model(model_path, + backend, + backend_config=backend_config) if vision_config is None: vision_config = VisionConfig() self.vision_config = vision_config self.max_batch_size = vision_config.max_batch_size torch.cuda.empty_cache() - self._que: asyncio.Queue = None - self._loop_task: asyncio.Task = None - if vision_config.thread_safe: - self._create_thread_safe_task() - - def _create_thread_safe_task(self): - """thread safe loop task.""" - self._loop = asyncio.new_event_loop() - def _work_thread(): - asyncio.set_event_loop(self._loop) - self._que = asyncio.Queue() - self._loop.run_until_complete(self._forward_loop()) - - thread = Thread(target=_work_thread, daemon=True) - thread.start() - self._loop_thread = thread - - def _create_event_loop_task(self): - """event loop task.""" - task = asyncio.get_event_loop().create_task(self._forward_loop()) - self._loop_task = task - self._loop = task.get_loop() - - @property - def req_que(self): - if self.vision_config.thread_safe: - return self._que - if self._que is None: - self._que = asyncio.Queue() - if self._loop_task is None: - self._create_event_loop_task() - if asyncio.get_event_loop() != self._loop: - raise RuntimeError('Current event loop is different from' - ' the one bound to loop task!') - return self._que - - async def _forward_loop(self): - """working loop to process images.""" - logger.info('start ImageEncoder._forward_loop') - record = Record(self.vision_config.thread_safe) - while True: - while record.total == 0 or (self._que.qsize() and - record.total < self.max_batch_size): - while self._que.qsize() == 0: - await asyncio.sleep(0.01) - item = await self._que.get() - record.enqueue(item[0], item[1], item[2]) - inputs, kwargs = record.dequeue(self.max_batch_size) - future = asyncio.get_event_loop().run_in_executor( - None, self.forward, inputs, kwargs) - future.add_done_callback(_raise_exception_on_finish) - outputs = await future - record.done.extend(outputs) - while record.notify(): - pass - - def _init_input_params(self, - inputs: List[Image], - params: List[Dict] = None): - """Check and init inputs params.""" - if params is None: - params = [{}] * len(inputs) - assert len(params) == len(inputs), \ - 'different length of inputs and kwargs' - return params - - def forward(self, inputs: List[Image], params: List[Dict] = None): - """Model forward.""" - params = self._init_input_params(inputs, params) - time_start = time.perf_counter() - func_params = inspect.signature(self.model.forward).parameters - func_inputs = [inputs, params] if len(func_params) > 1 else [inputs] - outputs = self.model.forward(*func_inputs) - if isinstance(outputs[0], torch.Tensor): - outputs = [x.cpu() for x in outputs] - time_end = time.perf_counter() - logger.info(f'ImageEncoder forward {len(inputs)} images, ' - f'cost {time_end - time_start:.3f}s') + async def preprocess(self, messages: List[Dict]) -> List[Dict]: + """preprocess multimodal data in the messages.""" + future = asyncio.get_event_loop().run_in_executor( + None, self.model.preprocess, messages) + future.add_done_callback(_raise_exception_on_finish) + outputs = await future return outputs - def infer(self, inputs: List[Image], params: List[Dict] = None): - """infer.""" - params = self._init_input_params(inputs, params) - results = self.forward(inputs, params) - return results + async def async_infer(self, messages: List[Dict]) -> List[Dict]: + """get multimodal embedding. + + Args: + messages (List[Dict]): a list of message, which is the output + of `preprocess()` + """ + future = asyncio.get_event_loop().run_in_executor( + None, self.model.forward, messages, self.max_batch_size) + future.add_done_callback(_raise_exception_on_finish) + outputs = await future + return outputs - async def async_infer(self, - inputs: List[Image], - params: List[Dict] = None): - """async infer.""" - params = self._init_input_params(inputs, params) - outputs = asyncio.Queue() - item = (inputs, params, outputs) - if self.vision_config.thread_safe: - self._loop.call_soon_threadsafe(self._que.put_nowait, item) - else: - self.req_que.put_nowait(item) - results = await outputs.get() - return results + async def wrap_for_pytorch(self, messages: List[Dict], chat_template, + tokenizer, sequence_start) -> List[Dict]: + """ + Args: + messages (List[Dict]): a list of message, which is supposed to be + the output of `preprocess` + Returns: + a dict which will be passed to pytorch engine_instance's forward. + The dict is like the following: + Dict( + 'prompt': 'the prompt after applying chat template' + 'input_ids': [], + 'multimodal': { + 'pixel_values': torch.Tensor, + ... + ] + ) + """ + result = self.model.to_pytorch(messages, chat_template, tokenizer, + sequence_start) + # clear data + for i, message in enumerate(messages): + if isinstance(message['content'], List): + messages[i]['preprocess'] = None + return result + + async def wrap_for_turbomind(self, messages: List[Dict], chat_template, + tokenizer, sequence_start) -> Dict: + """ + Args: + messages (List[Dict]): a list of message, which is supposed to be + the output of `async_infer` + Returns: + a dict which will be passed to pytorch engine_instance's forward. + The dict is like the following: + Dict( + 'prompt': 'the prompt after applying chat template' + 'input_ids': [], + 'input_embeddings': list[torch.Tensor], + 'input_embedding_ranges': list[torch.Tensor], + ... + """ + result = self.model.to_turbomind(messages, chat_template, tokenizer, + sequence_start) + # clear data + for i, message in enumerate(messages): + if isinstance(message['content'], List): + messages[i]['preprocess'] = None + messages[i]['forward'] = None + return result diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index 9c5f5f6e6a..0ee22b4688 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -2,8 +2,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union -import PIL -import torch +import numpy as np from mmengine import Registry from transformers import AutoConfig @@ -20,35 +19,227 @@ def __init__(self, model_path: str, with_llm: bool = False, max_memory: Dict[int, int] = None, - hf_config: AutoConfig = None): + hf_config: AutoConfig = None, + backend: str = ''): """init.""" self.model_path = model_path self.with_llm = with_llm self.max_memory = max_memory + self.backend = backend if hf_config is None: _, hf_config = get_model_arch(model_path) self.hf_config = hf_config - self.build_model() @abstractmethod - def build_model(): - """build model.""" + def build_preprocessor(self, ): + """build the preprocessor. + + NOTE: When the derived class implements this method, try not to + introduce the upper stream model repo as a thirdparty package + """ raise NotImplementedError() + def build_model(self, ): + """build the vision part of a VLM model when backend is turbomind. + + But when `with_llm=True`, load the whole VLM model + """ + if self.backend == 'turbomind' or self.with_llm: + raise NotImplementedError() + @abstractmethod + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """preprocess multimodal data in the messages. The derived class, + i.e., a specific vision model, takes the charge of image preprocessing + and the result management. + It can integrate the result into the messages list, or insert it to + the individual image item. + Args: + message(Dict): multimodal data in a dict, which is as follows: + [ + {'role': 'user', 'content': 'user prompt'}, + {'role': 'assisant', 'content': 'AI reponse'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'string', + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + ... + ] + } + {....} + ] + Returns: + the message list with preprocessing results included, which is + determined by the derived classes + """ # noqa + raise NotImplementedError() + def forward(self, - images: List[PIL.Image.Image], - image_kwargs: List[Dict] = None) -> List[torch.Tensor]: - """extract image feature. + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. Args: - images (List[PIL.Image.Image]): input images - image_kwargs (List[Dict]): input kwargs for each images - + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model Return: - List[torch.Tensor]: extract image feature for each input image + the message list with forwarding results included, which is + determined by the derived classes """ - raise NotImplementedError() + if self.backend == 'turbomind': + raise NotImplementedError() + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + """pack the preprocessing results in a format compatible with what is + required by pytorch engine. ONLY implement it when the backend is + pytorch engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'pytorch': + raise NotImplementedError() + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + """pack the forwarding results in a format compatible with what is + required by turbomind engine. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'turbomind': + raise NotImplementedError() + + @staticmethod + def collect_images(messages): + """gather all images along with their respective parameters from the + messages and compile them into a single list. Each image is converted + to RGB color space. + + Args: + messages (List[Tuple[Image, Dict]]): a list of images with their + corresponding parameters + """ # noqa + images = [] + for message in messages: + content = message['content'] + if not isinstance(content, List): + continue + images.extend([ + (x['image'], + {k: v + for k, v in x.items() if k not in {'type', 'image'}}) + for x in content if x['type'] == 'image' + ]) + return images + + @staticmethod + def to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start): + """auxiliary function to pack the preprocessing results in a format + compatible with what is required by pytorch engine. + + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect all preprocessing result from messages + preps = [x['content'] for x in messages if x['role'] == 'preprocess'] + assert len(preps) == 1 + preps = preps[0] + + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(preps) + 1, ( + f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(preps)}') + + # calculate the image token offset for each image + input_ids = [] + for i, seg in enumerate(segs): + if i > 0 and i <= len(preps): + preps[i - 1].update(offset=len(input_ids)) + image_tokens = preps[i - 1]['image_tokens'] + image_token_id = preps[i - 1]['image_token_id'] + input_ids.extend([image_token_id] * image_tokens) + token_ids = tokenizer.encode(seg, + add_bos=((i == 0) and sequence_start)) + input_ids.extend(token_ids) + + return dict(prompt=prompt, input_ids=input_ids, multimodal=preps) + + @staticmethod + def to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start): + """auxiliary function to pack the forwarding results in a format + compatible with what is required by turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect image features from messages + features = [x['content'] for x in messages if x['role'] == 'forward'] + features = features[0] + features = [x.cpu().numpy() for x in features] + + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(features) + 1, ( + f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(features)}') + + # tokenizer prompt, and get input_embeddings and input_embedding_ranges + input_ids = [] + begins = [] + ends = [] + IMAGE_DUMMY_TOKEN_INDEX = 0 + for i, seg in enumerate(segs): + if i > 0 and i <= len(features): + image_dim = features[i - 1].shape[0] + begins.append(len(input_ids)) + ends.append(begins[-1] + image_dim) + input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) + seg_ids = tokenizer.encode(seg, + add_bos=((i == 0) and sequence_start)) + input_ids.extend(seg_ids) + ranges = np.stack([begins, ends], axis=1).tolist() + return dict(prompt=prompt, + input_ids=input_ids, + input_embeddings=features, + input_embedding_ranges=ranges) @classmethod def match(cls, config: AutoConfig): diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py index 2401b42259..00e668c034 100644 --- a/lmdeploy/vl/model/builder.py +++ b/lmdeploy/vl/model/builder.py @@ -2,6 +2,8 @@ import os from typing import Optional, Union +import torch + from lmdeploy.archs import get_model_arch from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig from lmdeploy.utils import get_logger, get_model @@ -29,6 +31,7 @@ def load_vl_model(model_path: str, + backend: str, with_llm: bool = False, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None): @@ -36,8 +39,9 @@ def load_vl_model(model_path: str, Args: model_path(str): the path or repo_id from model hub of the model - with_llm(bool): whether to remove the LLM part from the model. - When it is False, it means removing LLM part + backend(str): the name of inference backend + with_llm(bool): load LLM model or not. Set it to False for VLM + inference scenarios and True for VLM quantization backend_config: the config of the inference engine """ if not os.path.exists(model_path): @@ -49,7 +53,6 @@ def load_vl_model(model_path: str, max_memory = None if not with_llm: - import torch tp = getattr(backend_config, 'tp', 1) max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)} @@ -57,30 +60,21 @@ def load_vl_model(model_path: str, kwargs = dict(model_path=model_path, with_llm=with_llm, max_memory=max_memory, - hf_config=hf_config) + hf_config=hf_config, + backend=backend) for name, module in VISION_MODELS.module_dict.items(): try: if module.match(hf_config): logger.info(f'matching vision model: {name}') - return module(**kwargs) - except Exception: - logger.error(f'matching vision model: {name} failed') + model = module(**kwargs) + model.build_preprocessor() + # build the vision part of a VLM model when backend is + # turbomind, or load the whole VLM model when `with_llm==True` + if backend == 'turbomind' or with_llm: + model.build_model() + return model + except Exception as e: + logger.error(f'build vision model {name} failed, {e}') raise raise ValueError(f'unsupported vl model with config {hf_config}') - - -def vl_model_with_tokenizer(model_path: str, with_llm: bool = True): - """load visual model.""" - vl_model = load_vl_model(model_path, with_llm).vl_model - llm = vl_model - if hasattr(vl_model, 'language_model'): # deepseek vl - llm = vl_model.language_model - if hasattr(vl_model, 'llm'): # MiniCPMV - llm = vl_model.llm - llm.config.use_cache = False - llm.half().eval() - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path, - trust_remote_code=True) - return vl_model, llm, tokenizer diff --git a/lmdeploy/vl/model/cogvlm.py b/lmdeploy/vl/model/cogvlm.py index ea5a06159e..07d97153f9 100644 --- a/lmdeploy/vl/model/cogvlm.py +++ b/lmdeploy/vl/model/cogvlm.py @@ -1,13 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings -from typing import List - -import torch -from PIL.Image import Image -from transformers import AutoModelForCausalLM +from typing import Dict, List +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging + +logger = get_logger('lmdeploy') @VISION_MODELS.register_module() @@ -16,7 +13,7 @@ class CogVLMVisionModel(VisonModel): _arch = 'CogVLMForCausalLM' - def build_model(self): + def build_preprocessor(self): from torchvision import transforms self.image_transform = transforms.Compose([ transforms.Resize( @@ -26,57 +23,73 @@ def build_model(self): transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) + image_size = self.hf_config.vision_config['image_size'] + patch_size = self.hf_config.vision_config['patch_size'] + self.n_token_per_image = 2 + (image_size // patch_size // 2)**2 - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - with init_empty_weights(), warnings.catch_warnings(): - model = AutoModelForCausalLM.from_config(self.hf_config, - trust_remote_code=True) - if not self.with_llm: - del model.lm_head - for key in ['layers', 'norm', 'embed_tokens']: - setattr(model.model, key, None) - else: - self.vl_model = model + def build_model(self): + if self.with_llm: + from transformers import AutoModelForCausalLM + self.vl_model = AutoModelForCausalLM.from_pretrained( + self.model_path, device_map='cpu', trust_remote_code=True) + else: + raise NotImplementedError('turbomind has not supported cogvlm yet') - no_split_module_classes = ['TransformerLayer'] - max_memory = get_balanced_memory( - model, - max_memory=self.max_memory, - dtype=torch.half, - no_split_module_classes=no_split_module_classes) - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_module_classes, - max_memory=max_memory, - dtype=torch.half) - same_device_keys = [('model.vision.linear_proj', 'model.vision.boi', - 'model.vision.eoi')] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to the spec of `super().preprocess`""" + images = self.collect_images(messages) + outputs = [] + for image, _ in images: + image = image.convert('RGB') + pixel_values = self.image_transform(image) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + + prompt_messages.append( + dict(role='user', content=content[0], num_images=n_images)) - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - self.model = model.model.vision - self.model.eval() + from lmdeploy.model import Vicuna + llm_chat_template = Vicuna(eoa=chat_template.eoa, + stop_words=chat_template.stop_words) + prompt = '' + IMAGE_TOKEN = '' + for i, msg in enumerate(prompt_messages): + num_images = msg.pop('num_images', 0) + if num_images == 0: + role = msg['role'] + msg = llm_chat_template.messages2prompt([msg], sequence_start + and i == 0) + msg = dict(role=role, content=msg) + prompt_i = chat_template.messages2prompt([msg], sequence_start + and i == 0) + if num_images > 0: + prompt_i = (IMAGE_TOKEN * num_images) + prompt_i + prompt += prompt_i + return prompt, IMAGE_TOKEN - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.image_transform(x) for x in outputs] - outputs = torch.stack(outputs, dim=0).to(device='cuda:0', - dtype=torch.half) - outputs = self.model(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/deepseek.py b/lmdeploy/vl/model/deepseek.py index bfbf03f01e..9780744cf2 100644 --- a/lmdeploy/vl/model/deepseek.py +++ b/lmdeploy/vl/model/deepseek.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. - import warnings -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoModelForCausalLM +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + def check_deepseek_vl_install(): """check deepseek_vl install.""" @@ -18,8 +19,8 @@ def check_deepseek_vl_install(): except ImportError: raise ImportError( 'To use DeepSeekVLModel, please install deepseek_vl by ' - 'pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git' - ' --no-deps') + '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git' + ' --no-deps`') @VISION_MODELS.register_module() @@ -28,18 +29,22 @@ class DeepSeekVisionModel(VisonModel): _arch = 'MultiModalityCausalLM' - def build_model(self): + def build_preprocessor(self): check_deepseek_vl_install() - # empty init - from accelerate import init_empty_weights from deepseek_vl.models import VLChatProcessor + self.image_processor = VLChatProcessor.from_pretrained( + self.model_path).image_processor + + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import init_empty_weights with init_empty_weights(): warnings.simplefilter('ignore') model = AutoModelForCausalLM.from_pretrained(self.model_path) + self.vl_model = model if not self.with_llm: del model.language_model - else: - self.vl_model = model from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory(model, @@ -79,23 +84,111 @@ def build_model(self): device_map=device_map if not self.with_llm else {'': 'cpu'}, dtype=torch.half) + self.model = model.eval() self.vision_model = model.vision_model.eval() self.aligner = model.aligner.eval() - self.image_processor = VLChatProcessor.from_pretrained( - self.model_path).image_processor + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to the spec of `super.preprocess()""" + images = self.collect_images(messages) + outputs = [] + for image, _ in images: + image = image.convert('RGB') + pixel_values = self.image_processor( + [image], return_tensors='pt').pixel_values + outputs.append( + dict( + pixel_values=pixel_values, + image_size=image.size, + # refer to https://github.com/deepseek-ai/DeepSeek-VL/blob/main/deepseek_vl/models/processing_vlm.py # noqa + # which is hardcoded 576 + image_tokens=576, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - pixel_values = self.image_processor(outputs, - return_tensors='pt').pixel_values - pixel_values = pixel_values.to(device=next( - self.vision_model.parameters()).device, - dtype=torch.float16) - # [b x n_images, T2, D] - images_embeds = self.aligner(self.vision_model(pixel_values)) - - outputs = torch.split(images_embeds, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=next( + self.vision_model.parameters()).device, + dtype=torch.float16) + # [b x n_images, T2, D] + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.aligner(self.vision_model(pixel_values)) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + # apply chat template to get the prompt + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + content = content[0] + n_image = sum( + [1 for x in message['content'] if x['type'] == 'image']) + n_placeholder = content.count(IMAGE_TOKEN) + if n_placeholder == 0: + logger.warning( + f"""for deepseek-vl model, the user should insert the {IMAGE_TOKEN} + to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html + for more details.""") # noqa + if n_placeholder != 0 and n_placeholder != n_image: + logger.error( + f'unmatched placeholder and image: {n_placeholder} vs ' + f'{n_image}. Ignore the placeholder') + content = content.replace(IMAGE_TOKEN, '') + n_placeholder = 0 + if n_placeholder == 0: + if n_image == 1: + content = f'{IMAGE_TOKEN}{content}' + else: + content = ''.join([ + f'{IMAGE_TOKEN} is Figure {str(i)}.\n' + for i in range(n_image) + ]) + content + prompt_messages.append(dict(role='user', content=content)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/glm_4v.py b/lmdeploy/vl/model/glm_4v.py index 34e060f4c9..813813bf09 100644 --- a/lmdeploy/vl/model/glm_4v.py +++ b/lmdeploy/vl/model/glm_4v.py @@ -1,77 +1,30 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List -import warnings -from typing import List - -import torch -from PIL.Image import Image from transformers import AutoConfig +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging + +logger = get_logger('lmdeploy') @VISION_MODELS.register_module() class GLM4VisionModel(VisonModel): """glm-4v-9b vision model.""" - _arch = 'ChatGLMModel' + _arch = ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] @classmethod def match(cls, config: AutoConfig): """check whether the config match the model.""" arch = config.architectures[0] - if arch == cls._arch and hasattr(config, 'vision_config'): + if arch in cls._arch and hasattr(config, 'vision_config'): return True return False - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import infer_auto_device_map + def build_preprocessor(self): from torchvision import transforms - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_config(self.hf_config, - trust_remote_code=True) - if not self.with_llm: - del model.transformer.embedding - del model.transformer.rotary_pos_emb - del model.transformer.encoder - del model.transformer.output_layer - else: - self.vl_model = model - - no_split_module_classes = ['TransformerLayer'] - - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_module_classes, - max_memory=self.max_memory, - dtype=torch.half) - - same_device_keys = [ - ('transformer.vision.linear_proj', 'transformer.vision.boi', - 'transformer.vision.eoi') - ] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: - continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] - - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - - model.eval() - self.model = model self.image_transform = transforms.Compose([ transforms.Resize( (self.hf_config.vision_config['image_size'], ) * 2, @@ -80,15 +33,65 @@ def build_model(self): transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) + image_size = self.hf_config.vision_config['image_size'] + patch_size = self.hf_config.vision_config['patch_size'] + self.n_token_per_image = 2 + (image_size // patch_size // 2)**2 + + def build_model(self): + if self.with_llm: + from transformers import AutoModelForCausalLM + self.vl_model = AutoModelForCausalLM.from_pretrained( + self.model_path, device_map='cpu', trust_remote_code=True) + else: + raise NotImplementedError('turbomind has not supported glm4v yet') + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to the spec of `super.preprocess()""" + outputs = [] + for message in messages: + if not isinstance(message['content'], List): + continue + images = [ + x['image'] for x in message['content'] if x['type'] == 'image' + ] + if len(images) > 1: + logger.warning( + f'glm4v does not support the input of multiple images' + f' in a single chat round, but got {len(images)} images.') + # we still pass all the images to the model and let the + # model decide what to do + images = [x.convert('RGB') for x in images] + pixel_values = [self.image_transform(x) for x in images] + outputs.extend([ + dict(pixel_values=_2, + image_size=_1.size, + image_tokens=self.n_token_per_image, + image_token_id=0) for _1, _2 in zip(images, pixel_values) + ]) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + content = message['content'] + if isinstance(content, str): + prompt_messages.append(message) + continue + elif message['role'] in ['preprocess', 'forward']: + continue + prompt = [x['text'] for x in content if x['type'] == 'text'] + n_images = len([1 for x in content if x['type'] == 'image']) + prompt = ''.join([f'{IMAGE_TOKEN}\n'] * n_images) + prompt[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.image_transform(x) for x in outputs] - outputs = torch.stack(outputs, dim=0).to(device='cuda:0', - dtype=torch.half) - outputs = self.model.transformer.vision(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py index fa67192f11..979b8d1a39 100644 --- a/lmdeploy/vl/model/internvl.py +++ b/lmdeploy/vl/model/internvl.py @@ -1,10 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. - from typing import Dict, List import torch -from PIL.Image import Image -from transformers import AutoModel, CLIPImageProcessor +from transformers import AutoConfig, AutoModel, CLIPImageProcessor from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel @@ -80,34 +78,16 @@ class InternVLVisionModel(VisonModel): _arch = 'InternVLChatModel' - def build_model(self): - """Load model.""" - from accelerate import init_empty_weights - with init_empty_weights(): - config = self.hf_config - # transformers below 4.37.0 may raise error about flash_attn - config.llm_config.attn_implementation = 'eager' - model = AutoModel.from_config(config, trust_remote_code=True) - if not self.with_llm: - del model.language_model - else: - self.vl_model = model - model.half() + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) - from accelerate import load_checkpoint_and_dispatch - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map='auto' if not self.with_llm else {'': 'cpu'}, - max_memory=self.max_memory, - no_split_module_classes=['InternVisionEncoderLayer'], - dtype=torch.half) - - # We need eval mode to freeze the weights in model, thus, - # avoid randomness in inference. - self.model = model.eval() - self.config = config + def build_preprocessor(self): + self.config = self.hf_config dynamic_image_size = getattr(self.config, 'dynamic_image_size', False) image_processor = None try: @@ -131,62 +111,180 @@ def build_model(self): T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) + self.processor = self._preprocess_v1_5 self._forward_func = self._forward_v1_5 else: + self.processor = self._preprocess self.image_processor = image_processor self._forward_func = self._forward - def _preprocess_v1_5(self, images: List[Image], params: List[Dict] = None): - if params is not None: - assert len(images) == len( - params), 'different length of images and params' - else: - params = [{}] * len(images) + force_image_size = self.hf_config.force_image_size + patch_size = self.hf_config.vision_config.patch_size + downsample_ratio = self.hf_config.downsample_ratio + self.image_tokens_per_patch = int( + (force_image_size // patch_size)**2 * (downsample_ratio**2)) - image_res = {'low': 6, 'medium': 12, 'high': 24} + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import init_empty_weights + with init_empty_weights(): + # transformers below 4.37.0 may raise error about flash_attn + self.config.llm_config.attn_implementation = 'eager' + model = AutoModel.from_config(self.config, trust_remote_code=True) + self.vl_model = model + if not self.with_llm: + del model.language_model - outputs = [] - for image, param in zip(images, params): - max_num = param.get('max_dynamic_patch') - if max_num is None or not isinstance(max_num, int): - res_key = param.get('detail', 'default') - max_num = image_res.get(res_key, self.config.max_dynamic_patch) - out = dynamic_preprocess( - image, - min_num=self.config.min_dynamic_patch, - max_num=max_num, - image_size=self.config.vision_config.image_size, - use_thumbnail=self.config.use_thumbnail) - out = [self.transform(x) for x in out] - out = torch.stack(out) # (patch) x c x h x w - outputs.append(out) - return outputs + model.half() + from accelerate import load_checkpoint_and_dispatch + with disable_logging(): + load_checkpoint_and_dispatch( + model=model, + checkpoint=self.model_path, + device_map='auto' if not self.with_llm else {'': 'cpu'}, + max_memory=self.max_memory, + no_split_module_classes=['InternVisionEncoderLayer'], + dtype=torch.half) + + # We need eval mode to freeze the weights in model, thus, + # avoid randomness in inference. + self.model = model.eval() + + def _preprocess_v1_5(self, image, params=None): + image_res = {'low': 6, 'medium': 12, 'high': 24} + max_num = params.get('max_dynamic_patch') + if max_num is None or not isinstance(max_num, int): + res_key = params.get('detail', 'default') + max_num = image_res.get(res_key, self.config.max_dynamic_patch) + out = dynamic_preprocess( + image, + min_num=self.config.min_dynamic_patch, + max_num=max_num, + image_size=self.config.vision_config.image_size, + use_thumbnail=self.config.use_thumbnail) + pixel_values = [self.transform(x) for x in out] + # (patch) x c x h x w + pixel_values = torch.stack(pixel_values) + return pixel_values - def _forward_v1_5(self, images: List[Image], params: List[Dict] = None): + def _forward_v1_5(self, inputs, max_batch_size): """forward for internvl-chat-v1-5.""" - outputs = self._preprocess_v1_5(images, params) - split = [x.shape[0] for x in outputs] - outputs = torch.cat(outputs, dim=0) - outputs = outputs.to(self.model.device, dtype=torch.float16) - outputs = self.model.extract_feature(outputs) - outputs = torch.split(outputs, split, dim=0) - outputs = [x.reshape(-1, x.shape[-1]) for x in outputs] + assert all(x.get('pixel_values') is not None for x in inputs) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + split = [x.shape[0] for x in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(self.model.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model.extract_feature(pixel_values) + feats = torch.split(feats, split, dim=0) + outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats]) return outputs - def _forward(self, images: List[Image], params: List[Dict] = None): + def _preprocess(self, image, params=None): """forward for internvl-chat-v1-1, internvl-chat-v1-2.""" - pixel_values = self.image_processor(images=images, + pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values - pixel_values = pixel_values.to(self.model.device, dtype=torch.float16) - outputs = self.model.extract_feature(pixel_values) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] + return pixel_values + + def _forward(self, inputs, max_batch_size): + """forward for internvl-chat-v1-1, internvl-chat-v1-2.""" + assert all(x.get('pixel_values') is not None for x in inputs) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(self.model.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model.extract_feature(pixel_values) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) return outputs + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = self.processor(image, params) + image_tokens = (pixel_values.shape[0] * + self.image_tokens_per_patch) + outputs.append( + dict(pixel_values=pixel_values, + image_tokens=image_tokens, + image_token_id=0, + image_size=image.size)) + messages.append(dict(role='preprocess', content=outputs)) + return messages + @torch.no_grad() def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - images = [x.convert('RGB') for x in images] - return self._forward_func(images, params) + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = self._forward_func(inputs, max_batch_size) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt and f'{IMAGE_TOKEN}' not in prompt: + prompt = prompt.replace(f'{IMAGE_TOKEN}', + f'{IMAGE_TOKEN}') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + elif IMAGE_TOKEN not in prompt: + prompt = f'{IMAGE_TOKEN * n_images}\n' + prompt + else: + pass + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/internvl_llava.py b/lmdeploy/vl/model/internvl_llava.py index f607082b18..17a12f71ca 100644 --- a/lmdeploy/vl/model/internvl_llava.py +++ b/lmdeploy/vl/model/internvl_llava.py @@ -2,14 +2,13 @@ import warnings from contextlib import contextmanager -from typing import List, Union +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoConfig, AutoModelForCausalLM from lmdeploy.utils import get_logger -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel from lmdeploy.vl.model.utils import rewrite_ctx from .utils import disable_logging, disable_transformers_logging @@ -18,14 +17,13 @@ def check_llava_install(): - """check llava install.""" try: from llava.model.multimodal_encoder.clip_encoder import \ InternVisionModel # noqa: F401 except ImportError: raise ImportError( 'To use LlavaVLModel, please install llava by ' - 'pip install "git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava" --no-deps' # noqa: E501 + '`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`' # noqa: E501 ) @@ -65,7 +63,7 @@ def init_empty_vit(): @VISION_MODELS.register_module() -class InternVLLlavaVisionModel(VisonModel): +class InternVLLlavaVisionModel(LlavaVisionModel): """Llava visual model.""" @classmethod @@ -78,9 +76,12 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + return super().build_preprocessor() + def build_model(self): - """build model & load weights.""" - # check llava install + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_llava_install() # currently, only support llava llama from llava.model.language_model.llava_llama import ( # noqa @@ -98,13 +99,12 @@ def build_model(self): } # disable vision part quantization model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: del model.lm_head del model.model.embed_tokens del model.model.layers del model.model.norm - else: - self.vl_model = model with init_empty_vit(): vision_tower = model.get_vision_tower() @@ -137,42 +137,43 @@ def build_model(self): self.vision_tower = model.model.vision_tower.eval() self.mm_projector = model.model.mm_projector.eval() - def encode_images(self, images: torch.Tensor) -> torch.Tensor: - """encode images.""" - image_features = self.vision_tower(images) - image_features = self.mm_projector(image_features) - return image_features - - def preprocess( - self, - images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]: - """preprocess.""" - # TODO: gpu processor - from llava.mm_utils import process_images - images = [x.convert('RGB') for x in images] - image_processor = self.vision_tower.image_processor - outputs = process_images(images, image_processor, self.config) - return outputs + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + return super().preprocess(messages) @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - images = self.preprocess(images) - if isinstance(images, list): - images = [ - x.to(self.vision_tower.device, dtype=torch.float16) - for x in images + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] ] - else: - images = images.to(self.vision_tower.device, dtype=torch.float16) - - if type(images) is list or images.ndim == 5: - concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) - split_sizes = [image.shape[0] for image in images] - image_features = torch.split(image_features, split_sizes, dim=0) - image_features = [x.flatten(0, 1) for x in image_features] - else: - image_features = self.encode_images(images) - image_features = [x for x in image_features] - return image_features + split_sizes = [x.shape[0] for x in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + if pixel_values.ndim == 5: + feats = self.encode_images(pixel_values) + feats = torch.split(feats, split_sizes, dim=0) + feats = [x.flatten(0, 1) for x in feats] + else: + feats = self.encode_images(pixel_values) + feats = [x for x in feats] + outputs.extend(feats) + messages.append(dict(role='forward', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/llava.py b/lmdeploy/vl/model/llava.py index 0b18f460cd..7ad919bef7 100644 --- a/lmdeploy/vl/model/llava.py +++ b/lmdeploy/vl/model/llava.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -# Modified from -# https://github.com/haotian-liu/LLaVA.git +# Modified from https://github.com/haotian-liu/LLaVA.git +import ast +import math import warnings from contextlib import contextmanager -from typing import List, Union +from typing import Dict, List import torch -from PIL.Image import Image +from PIL import Image from transformers import AutoConfig, AutoModelForCausalLM from lmdeploy.utils import get_logger -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel from lmdeploy.vl.model.utils import disable_logging, rewrite_ctx logger = get_logger('lmdeploy') @@ -23,16 +24,14 @@ def check_llava_install(): except ImportError: raise ImportError( 'To use LlavaVLModel, please install llava by ' - 'pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps' # noqa: E501 + '`pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps`' # noqa: E501 ) def _clip_vision_tower_load_model(self, **kwargs): logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}') - from transformers import (CLIPImageProcessor, CLIPVisionConfig, - CLIPVisionModel) - self.image_processor = CLIPImageProcessor.from_pretrained( - self.vision_tower_name) + from transformers import CLIPVisionConfig, CLIPVisionModel + config = CLIPVisionConfig.from_pretrained(self.vision_tower_name) self.vision_tower = CLIPVisionModel._from_config(config=config) self.vision_tower.requires_grad_(False) @@ -53,8 +52,166 @@ def init_llava_vision_tower(config): yield +def select_best_resolution(original_size, possible_resolutions): + """Selects the best resolution from a list of possible resolutions based on + the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ # noqa + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int( + original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, + original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """Resize and pad an image to a target resolution while maintaining aspect + ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ # noqa + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def process_anyres_image(image, processor, grid_pinpoints): + """Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ # noqa + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size['height']) + + image_original_resize = image.resize( + (processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [ + processor.preprocess(image_patch, + return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches + ] + return torch.stack(image_patches, dim=0) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None) + new_images = [] + if image_aspect_ratio == 'pad': + for image in images: + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean)) + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + elif image_aspect_ratio == 'anyres': + for image in images: + image = process_anyres_image(image, image_processor, + model_cfg.image_grid_pinpoints) + new_images.append(image) + else: + return image_processor(images, return_tensors='pt')['pixel_values'] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + @VISION_MODELS.register_module() -class LlavaVisionModel(VisonModel): +class LlavaVisionModel(LlavaHfVisionModel): """Llava visual model.""" @classmethod @@ -73,9 +230,20 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + from transformers import CLIPImageProcessor + self.image_processor = CLIPImageProcessor.from_pretrained( + self.hf_config.mm_vision_tower) + config = AutoConfig.from_pretrained(self.hf_config.mm_vision_tower) + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.n_token_per_image = (image_size // patch_size)**2 + if self.hf_config.mm_vision_select_feature == 'cls_patch': + self.n_token_per_image += 1 + def build_model(self): - """build model & load weights.""" - # check llava install + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_llava_install() self.arch = self.hf_config.architectures[0] @@ -104,15 +272,13 @@ def build_model(self): model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: # remove the LLM part from llava model. - # Instead, Load the LLM part to turbomind engine del model.lm_head del model.model.embed_tokens del model.model.layers del model.model.norm - else: - self.vl_model = model # init empty vision_tower, the embedding layer in CLIPVisionModel # can't init right under init_empty_weights @@ -143,101 +309,113 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor: image_features = self.mm_projector(image_features) return image_features - def preprocess( - self, - images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]: - """preprocess.""" - # TODO: gpu processor - from llava.mm_utils import process_images - images = [x.convert('RGB') for x in images] - image_processor = self.vision_tower.image_processor - outputs = process_images(images, image_processor, self.config) - return outputs + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = process_images([image], self.image_processor, + self.config) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ from llava.model.llava_arch import (get_anyres_image_grid_shape, unpad_image) - image_sizes = [x.size for x in images] - images = self.preprocess(images) - if isinstance(images, list): - images = [ - x.to(device=self.vision_tower.device, dtype=torch.float16) - for x in images + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + image_sizes = [ + x['image_size'] for x in inputs[idx:idx + max_batch_size] ] - else: - images = images.to(device=self.vision_tower.device, - dtype=torch.float16) - if type(images) is list or images.ndim == 5: - if type(images) is list: - images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] - concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) - split_sizes = [image.shape[0] for image in images] - image_features = torch.split(image_features, split_sizes, dim=0) - mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', - 'flat') - image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', - 'square') - if mm_patch_merge_type == 'flat': - image_features = [x.flatten(0, 1) for x in image_features] - elif mm_patch_merge_type.startswith('spatial'): - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - height = width = self.vision_tower.num_patches_per_side - assert height * width == base_image_feature.shape[0] - if image_aspect_ratio == 'anyres': - num_patch_width, num_patch_height = \ - get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.vision_tower.config.image_size) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, - width, -1) - else: - raise NotImplementedError - if 'unpad' in mm_patch_merge_type: - image_feature = image_feature.permute( - 4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, - 2).flatten( - 2, 3) - image_feature = unpad_image( - image_feature, image_sizes[image_idx]) - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1).to( - image_feature.device)), - dim=-1) - image_feature = image_feature.flatten(1, - 2).transpose( - 0, 1) + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + if pixel_values[0].ndim == 5: + split_sizes = [x.shape[1] for x in pixel_values] + pixel_values = torch.cat([x for x in pixel_values], dim=1) + logger.info(f'vision forward shape: {pixel_values.shape}') + pixel_values = pixel_values.squeeze(0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + feats = self.encode_images(pixel_values) + feats = torch.split(feats, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, + 'mm_patch_merge_type', 'flat') + image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', + 'square') + if mm_patch_merge_type == 'flat': + outputs.expand([x.flatten(0, 1) for x in feats]) + elif mm_patch_merge_type.startswith('spatial'): + for img_idx, feat in enumerate(feats): + if feat.shape[0] > 1: + base_feat = feat[0] + feat = feat[1:] + height = self.vision_tower.num_patches_per_side + width = self.vision_tower.num_patches_per_side + assert height * width == base_feat.shape[0] + if image_aspect_ratio == 'anyres': + num_patch_width, num_patch_height = \ + get_anyres_image_grid_shape( + image_sizes[img_idx], + self.config.image_grid_pinpoints, + self.vision_tower.config.image_size) + feat = feat.view(num_patch_height, + num_patch_width, height, + width, -1) + else: + raise NotImplementedError + if 'unpad' in mm_patch_merge_type: + feat = feat.permute(4, 0, 2, 1, 3).contiguous() + feat = feat.flatten(1, 2).flatten(2, 3) + feat = unpad_image(feat, image_sizes[img_idx]) + feat = torch.cat( + (feat, self.model. + image_newline[:, None, None].expand( + *feat.shape[:-1], 1).to(feat.device)), + dim=-1) + feat = feat.flatten(1, 2).transpose(0, 1) + else: + feat = feat.permute(0, 2, 1, 3, 4).contiguous() + feat = feat.flatten(0, 3) + feat = torch.cat((base_feat, feat), dim=0) else: - image_feature = image_feature.permute( - 0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.flatten(0, 3) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - if 'unpad' in mm_patch_merge_type: - image_feature = torch.cat( - (image_feature, - self.model.image_newline[None].to( - image_feature.device)), - dim=0) - new_image_features.append(image_feature) - image_features = new_image_features + feat = feat[0] + if 'unpad' in mm_patch_merge_type: + feat = torch.cat( + (feat, self.model.image_newline[None].to( + feat.device)), + dim=0) + outputs.append(feat) + else: + raise ValueError('Unexpected mm_patch_merge_type: ' + f'{self.config.mm_patch_merge_type}') else: - raise ValueError('Unexpected mm_patch_merge_type: ' - f'{self.config.mm_patch_merge_type}') - else: - image_features = self.encode_images(images) - image_features = [x for x in image_features] - return image_features + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.vision_tower.device, + dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.encode_images(pixel_values) + outputs.extend([x for x in feats]) + messages.append(dict(role='forward', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py index 31be101ae8..c4e3c90bfb 100644 --- a/lmdeploy/vl/model/llava_hf.py +++ b/lmdeploy/vl/model/llava_hf.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. - import warnings -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoProcessor +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() class LlavaHfVisionModel(VisonModel): @@ -17,19 +18,31 @@ class LlavaHfVisionModel(VisonModel): _arch = 'LlavaForConditionalGeneration' + def build_preprocessor(self): + processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True) + if hasattr(processor, 'tokenizer'): + del processor.tokenizer + processor.prtokenizer = None + self.processor = processor.image_processor + image_size = self.hf_config.vision_config.image_size + patch_size = self.hf_config.vision_config.patch_size + self.n_token_per_image = (image_size // patch_size)**2 + if self.hf_config.vision_feature_select_strategy == 'full': + self.n_token_per_image += 1 + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights, load_checkpoint_and_dispatch with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') from transformers import LlavaForConditionalGeneration model = LlavaForConditionalGeneration._from_config(self.hf_config) + self.vl_model = model if not self.with_llm: del model.language_model - for key in ['language_model']: - setattr(model, key, None) - else: - self.vl_model = model # fix for llava-hf/llava-interleave-qwen-7b-hf setattr(model.config, 'tie_word_embeddings', False) @@ -45,35 +58,97 @@ def build_model(self): dtype=torch.half) model.eval() self.model = model - # processor - processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True) - if hasattr(processor, 'tokenizer'): - del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = self.processor( + image, return_tensors='pt', + input_data_format='channels_last').pixel_values + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - pixel_values = self.processor( - images, return_tensors='pt', - input_data_format='channels_last')['pixel_values'] - pixel_values = pixel_values.to(device=self.model.device, - dtype=self.model.dtype) - image_outputs = self.model.vision_tower.forward( - pixel_values, output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.hf_config.vision_feature_layer] - if self.hf_config.vision_feature_select_strategy == 'default': - image_features = image_features[:, 1:] - elif self.hf_config.vision_feature_select_strategy == 'full': - image_features = image_features - else: - raise ValueError( - 'Unexpected select feature strategy: ' - f'{self.hf_config.vision_feature_select_strategy}') - image_features = self.model.multi_modal_projector(image_features) - outputs = torch.split(image_features, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(device=self.model.device, + dtype=self.model.dtype) + logger.info(f'vision forward shape: {pixel_values.shape}') + image_outputs = self.model.vision_tower.forward( + pixel_values, output_hidden_states=True) + image_features = image_outputs.hidden_states[ + self.hf_config.vision_feature_layer] + if self.hf_config.vision_feature_select_strategy == 'default': + image_features = image_features[:, 1:] + elif self.hf_config.vision_feature_select_strategy == 'full': + image_features = image_features + else: + raise ValueError( + 'Unexpected select feature strategy: ' + f'{self.hf_config.vision_feature_select_strategy}') + image_features = self.model.multi_modal_projector(image_features) + image_features = torch.split(image_features, 1, dim=0) + outputs.extend([x.squeeze() for x in image_features]) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = (IMAGE_TOKEN + '\n') * n_images + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py index 9223ebea4f..d355a48d60 100644 --- a/lmdeploy/vl/model/llava_next.py +++ b/lmdeploy/vl/model/llava_next.py @@ -1,46 +1,51 @@ # Copyright (c) OpenMMLab. All rights reserved. - +import itertools import warnings -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image -from transformers import AutoProcessor -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.utils import get_logger +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() -class LlavaNextVisionModel(VisonModel): +class LlavaNextVisionModel(LlavaHfVisionModel): """Llava hf vision model.""" _arch = 'LlavaNextForConditionalGeneration' - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - + def build_preprocessor(self): + super().build_preprocessor() + # build the model with empty weights. The model will be used in + # `preprocess` to get the image token number + from accelerate import init_empty_weights with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') from transformers import LlavaNextForConditionalGeneration - model = LlavaNextForConditionalGeneration._from_config( + self.model = LlavaNextForConditionalGeneration._from_config( self.hf_config) + self.vl_model = self.model if not self.with_llm: - del model.language_model - for key in ['language_model']: - setattr(model, key, None) - else: - self.vl_model = model + del self.model.language_model + + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import load_checkpoint_and_dispatch + from accelerate.utils import get_balanced_memory, infer_auto_device_map no_split_module_classes = ['CLIPEncoderLayer'] max_memory = get_balanced_memory( - model, + self.model, max_memory=self.max_memory, dtype=torch.half, no_split_module_classes=no_split_module_classes) device_map = infer_auto_device_map( - model, + self.model, no_split_module_classes=no_split_module_classes, max_memory=max_memory, dtype=torch.half) @@ -55,75 +60,128 @@ def build_model(self): with disable_logging(): load_checkpoint_and_dispatch( - model=model, + model=self.model, checkpoint=self.model_path, device_map=device_map if not self.with_llm else {'': 'cpu'}, no_split_module_classes=no_split_module_classes, dtype=torch.half) - model.eval() - self.model = model - # processor - processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True) - if hasattr(processor, 'tokenizer'): - del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + self.model.eval() - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to the spec of `super.preprocess()""" from transformers.models.llava_next.modeling_llava_next import \ image_size_to_num_patches - """forward.""" - processed_inputs = self.processor(images, - return_tensors='pt', - input_data_format='channels_last') - pixel_values = processed_inputs['pixel_values'].to( - device=self.model.device, dtype=self.model.dtype) - image_sizes = processed_inputs['image_sizes'].to( - device=self.model.device, dtype=self.model.dtype) - # ! infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.hf_config.image_grid_pinpoints, - patch_size=self.hf_config.vision_config.image_size, - ) for imsize in image_sizes - ] - # figure out if pixel_values is concatenated or stacked - if pixel_values.dim() == 5: - # stacking when input is - # (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] - for pix_val, num_patch in zip(pixel_values, image_num_patches) + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + result = self.processor(image, + return_tensors='pt', + input_data_format='channels_last') + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.hf_config.image_grid_pinpoints, + patch_size=self.hf_config.vision_config.image_size, + ) for imsize in result['image_sizes'] ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of - # (num_patches, num_channels, height, width) - raise ValueError(f'pixel_values of shape {pixel_values.shape}, ' - 'expect to be of 4 or 5 dimensions') - image_outputs = self.model.vision_tower.forward( - pixel_values, output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.hf_config.vision_feature_layer] - if self.hf_config.vision_feature_select_strategy == 'default': - image_features = image_features[:, 1:] - elif self.hf_config.vision_feature_select_strategy == 'full': - image_features = image_features - else: - raise ValueError( - 'Unexpected select feature strategy: ' - f'{self.hf_config.vision_feature_select_strategy}') - image_features = self.model.multi_modal_projector(image_features) - image_features = torch.split(image_features, image_num_patches, dim=0) - image_features, feature_lens = self.model.pack_image_features( - image_features, - image_sizes, - image_newline=self.model.image_newline, - ) - outputs = torch.split(image_features, - feature_lens.cpu().numpy().tolist(), - dim=0) - return outputs + + hidden_size = self.hf_config.text_config.hidden_size + fake_image_features = torch.zeros( + [image_num_patches[0], self.n_token_per_image, hidden_size]) + image_sizes = result['image_sizes'] + image_newline = torch.randn(self.hf_config.text_config.hidden_size) + strategy = self.hf_config.vision_feature_select_strategy + _, image_tokens = self.model.pack_image_features( + [fake_image_features], + image_sizes, + vision_feature_select_strategy=strategy, + image_newline=image_newline) + result.update( + dict(image_size=image.size, + image_patches=image_num_patches, + image_tokens=image_tokens, + image_token_id=0)) + outputs.append(result) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + @torch.no_grad() + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'].to(device=self.model.device, + dtype=self.model.dtype) + for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + image_sizes = [ + x['image_sizes'].to(device=self.model.device, + dtype=self.model.dtype) + for x in inputs[idx:idx + max_batch_size] + ] + image_sizes = torch.cat(image_sizes, dim=0) + image_num_patches = [ + x['num_patch'] for x in inputs[idx:idx + max_batch_size] + ] + image_num_patches = list(itertools.chain(*image_num_patches)) + # figure out if pixel_values is concatenated or stacked + if pixel_values.dim() == 5: + # stacking when input is + # (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] for pix_val, num_patch in zip( + pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of + # (num_patches, num_channels, height, width) + raise ValueError( + f'pixel_values of shape {pixel_values.shape}, ' + 'expect to be of 4 or 5 dimensions') + logger.info(f'vision forward shape: {pixel_values.shape}') + image_outputs = self.model.vision_tower.forward( + pixel_values, output_hidden_states=True) + image_features = image_outputs.hidden_states[ + self.hf_config.vision_feature_layer] + strategy = self.hf_config.vision_feature_select_strategy + if strategy == 'default': + image_features = image_features[:, 1:] + elif strategy == 'full': + image_features = image_features + else: + raise ValueError('Unexpected select feature strategy: ' + f'{strategy}') + image_features = self.model.multi_modal_projector(image_features) + image_features = torch.split(image_features, + image_num_patches, + dim=0) + image_features, feature_lens = self.model.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=strategy, + image_newline=self.model.image_newline, + ) + image_features = torch.split(image_features, + feature_lens.cpu().numpy().tolist(), + dim=0) + outputs.extend(image_features) + messages.append(dict(role='forward', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/mini_gemeni.py b/lmdeploy/vl/model/mini_gemeni.py index 0565daeba5..eca70aca51 100644 --- a/lmdeploy/vl/model/mini_gemeni.py +++ b/lmdeploy/vl/model/mini_gemeni.py @@ -3,16 +3,18 @@ import os.path as osp import warnings from contextlib import contextmanager -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import (add_device_hook, disable_logging, disable_transformers_logging, hack_import_with) +logger = get_logger('lmdeploy') + def check_mini_gemini_install(): """check mini gemini install.""" @@ -22,8 +24,8 @@ def check_mini_gemini_install(): except ImportError: raise ImportError( 'To use MiniGeminiVisionModel, please install minigemini by ' - 'pip install git+https://github.com/dvlab-research/MGM.git' - ' --no-deps') + '`pip install git+https://github.com/dvlab-research/MGM.git' + ' --no-deps`') def _build_vision_tower(vision_tower_cfg, **kwargs): @@ -169,7 +171,15 @@ class MiniGeminiVisionModel(VisonModel): _arch = ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM'] + def build_preprocessor(self): + # pytorch engine will not support mini-gemini. Therefore, in order to + # reuse the previous code as much as possible, we do not extract image + # preprocessor from `build_model` function. + pass + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_mini_gemini_install() # empty init from accelerate import init_empty_weights @@ -193,13 +203,12 @@ def build_model(self): vision_tower.load_model() vision_tower_aux = model.get_vision_tower_aux() vision_tower_aux.load_model() + self.vl_model = model if not self.with_llm: del model.lm_head del model.model.embed_tokens del model.model.layers del model.model.norm - else: - self.vl_model = model from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory( @@ -246,11 +255,35 @@ def build_model(self): self.image_processor = image_processor self.process_images = process_images + def preprocess(self, messages: List[Dict]) -> List[Dict]: + return messages + @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - image_tensor = self.process_images(outputs, self.image_processor, + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + images = [] + for message in messages: + if not isinstance(message['content'], List): + continue + _ = [ + x['image'] for x in message['content'] if x['type'] == 'image' + ] + assert len(_) == 1, f'MiniGeminiLlama accepts ONE input ' \ + f'image, but got {len(images)} images' + images.extend(_) + + image_tensor = self.process_images(images, self.image_processor, self.model.config) image_grid = getattr(self.model.config, 'image_grid', 1) if hasattr(self.model.config, 'image_size_aux'): @@ -301,15 +334,47 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]: image.to(self.model.device, dtype=torch.float16) for image in image_tensor_aux ] + logger.info(f'vision forward bs: {len(image_tensor)}') else: image_tensor = image_tensor.to(self.model.device, dtype=torch.float16) image_tensor_aux = image_tensor_aux.to(self.model.device, dtype=torch.float16) - + logger.info(f'vision forward shape: {image_tensor.shape}') images_embeds = self.model.encode_images(image_tensor, image_tensor_aux) outputs = torch.split(images_embeds, 1, dim=0) outputs = [x.squeeze() for x in outputs] - return outputs + messages.append(dict(role='forward', cotent=outputs)) + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = (IMAGE_TOKEN + '\n') * n_images + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + assert 0, 'cogvlm is not supported by pytorch engine' + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py index 4e30190c1d..6b0c5f1508 100644 --- a/lmdeploy/vl/model/minicpmv.py +++ b/lmdeploy/vl/model/minicpmv.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import itertools import warnings from typing import Dict, List import torch from PIL.Image import Image -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel @@ -19,8 +20,33 @@ class MiniCPMVModel(VisonModel): _arch = 'MiniCPMV' + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) + if not hasattr(self.hf_config, 'version'): + raise ValueError('Can not find `version` in config.json. ' + 'Please checkout the latest model') + version = str(self.hf_config.version) + if version not in ['2.5', '2.6']: + raise ValueError( + f'Only support v2.5 and v2.6, but got version {version}') + self.version = version + + def build_preprocessor(self): + from transformers import AutoProcessor + self.processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True) + self.image_processor = self.processor.image_processor + self._preprocess_func = (self._preprocess_v2_5 if self.version == '2.5' + else self._preprocess_v2_6) + def build_model(self): - """build model & load weights.""" + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -29,10 +55,9 @@ def build_model(self): config.quantization_config = {} # disable vision part quantization model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: del model.llm - else: - self.vl_model = model from accelerate import load_checkpoint_and_dispatch with disable_logging(): @@ -50,46 +75,11 @@ def build_model(self): device=model.resampler.proj.device) self.config = config self.model = model.eval() - self.init_forward_func() - - def init_forward_func(self): - if not hasattr(self.config, 'version'): - msg = 'LMDeploy only support `MiniCPM-V-2_6` and '\ - '`MiniCPM-Llama3-V-2_5`.\nCan not find `version` in config, ' \ - 'please consider update the huggingface model.' - logger.warn(msg) - - self._forward_func = self._forward_v2_5 - if hasattr(self.config, 'version'): - version = str(self.config.version) - if version == '2.6': - self._forward_func = self._forward_v2_6 - - if self._forward_func == self._forward_v2_5: - logger.info('using _forward_v2_5') - if not hasattr(self.model, 'slice_image'): - # adapt new code commit 287e3f85 (MiniCPM-Llama3-V-2_5) - from transformers import AutoProcessor - processor = AutoProcessor.from_pretrained( - self.model_path, trust_remote_code=True) - self.model.slice_image = processor.image_processor.slice_image - - def _reshape_by_patch(x): - out = x.cpu().numpy() - out = processor.image_processor.reshape_by_patch(out) - return torch.from_numpy(out).to(device=x.device) - - self.model.reshape_by_patch = _reshape_by_patch - - if self._forward_func == self._forward_v2_6: - logger.info('using _forward_v2_6') - from transformers import AutoProcessor - self.model.processor = AutoProcessor.from_pretrained( - self.model_path, trust_remote_code=True) def _get_slice_image(self, image: Image): slice_images = [] - source_image, patches, best_grid = self.model.slice_image(image) + source_image, patches, best_grid = self.image_processor.slice_image( + image) slice_images.append(source_image) if len(patches) > 0: for i in range(len(patches)): @@ -103,114 +93,198 @@ def _reshape_by_patch(self, slice_images): for slice_image in slice_images: slice_image = self.model.transform(slice_image) H, W = slice_image.shape[1:] - patches.append(self.model.reshape_by_patch(slice_image)) + slice_image = slice_image.numpy() + slice_image = self.image_processor.reshape_by_patch(slice_image) + slice_image = torch.from_numpy(slice_image) + patches.append(slice_image) H //= self.config.patch_size W //= self.config.patch_size tgt_sizes.append(torch.Tensor([H, W]).type(torch.int32)) return patches, tgt_sizes - def _forward_v2_5(self, images: List[Image], params: List[Dict] = None): - """forward for MiniCPM-Llama3-V-2_5.""" - patches = [] - tgt_sizes = [] - best_grids = [] - num_patches = [] - for image in images: - slice_images, best_grid = self._get_slice_image(image) - _patches, _tgt_sizes = self._reshape_by_patch(slice_images) - num_patches.append(len(_patches)) - patches.extend(_patches) - tgt_sizes.extend(_tgt_sizes) - best_grids.append(best_grid) - - patches = [ - x.to(dtype=torch.half, device=self.model.device) for x in patches - ] - patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches] - tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) - max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches, - batch_first=True, - padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute(0, 2, - 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=self.model.device) - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True - vision_embedding = self.model.vpm( - all_pixel_values.type(torch.half), - patch_attention_mask=patch_attn_mask).last_hidden_state - vision_embedding = self.model.resampler(vision_embedding, tgt_sizes) - vision_embedding = torch.split(vision_embedding, num_patches, 0) + def _preprocess_v2_5(self, image: Image, params: Dict = None) -> Dict: + """image preprocessing for MiniCPM-Llama3-V-2_5.""" + slice_images, best_grid = self._get_slice_image(image) + # pixel_values, tgt_sizes are list of torch tensors + pixel_values, tgt_sizes = self._reshape_by_patch(slice_images) + num_patches = len(pixel_values) + return dict( + pixel_values=pixel_values, # a list + tgt_sizes=tgt_sizes, # a list + best_grid=best_grid, + num_patches=num_patches, + image_tokens=1, + image_token_id=0) + + def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict: + """image preprocessing for MiniCPM-V-2_6.""" + max_slice_nums = self.image_processor.max_slice_nums + use_image_id = self.image_processor.use_image_id + max_slice_nums = params.get('max_slice_nums', max_slice_nums) + use_image_id = params.get('use_image_id', use_image_id) + outputs = self.image_processor(image, max_slice_nums=max_slice_nums) + pixel_values = outputs['pixel_values'][0] + num_patches = len(pixel_values) + pixel_values = [torch.as_tensor(x) for x in pixel_values] + tgt_sizes = outputs['tgt_sizes'][0] + tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes] + grid = self.image_processor.get_sliced_grid( + image_size=image.size, max_slice_nums=max_slice_nums) + return dict( + pixel_values=pixel_values, # a list + tgt_sizes=tgt_sizes, # a list + best_grid=grid, + num_patches=num_patches, + image_tokens=1, + image_token_id=0, + use_image_id=use_image_id) + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" outputs = [] - for embeddings, grid in zip(vision_embedding, best_grids): - embeddings = embeddings.cpu() # n x d x h - outputs.append(dict(embeddings=embeddings, grid=grid)) + for i, message in enumerate(messages): + if message['role'] != 'user' or not isinstance( + message['content'], List): + continue + for item in message['content']: + if item['type'] == 'image': + image = item['image'].convert('RGB') + params = { + k: v + for k, v in item.items() if k not in {'type', 'image'} + } + result = self._preprocess_func(image, params) + outputs.append(result) + messages[i].update(dict(preprocess=outputs)) + return messages - return outputs + @torch.no_grad() + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. - def _forward_v2_6(self, images: List[Image], params: List[Dict] = None): - """forward for MiniCPM-V-2_6.""" - patches = [] - tgt_sizes = [] - best_grids = [] - num_patches = [] - max_slice_nums = self.model.processor.image_processor.max_slice_nums - use_image_id = self.model.processor.image_processor.use_image_id - for image, param in zip(images, params): - max_slice_nums = param.get('max_slice_nums', max_slice_nums) - use_image_id = param.get('use_image_id', use_image_id) - outputs = self.model.processor.image_processor( - image, max_slice_nums=max_slice_nums) - patches.extend(outputs['pixel_values'][0]) - num_patches.append(len(outputs['pixel_values'][0])) - tgt_sizes.extend(outputs['tgt_sizes'][0]) - grid = self.model.processor.image_processor.get_sliced_grid( - image_size=image.size, max_slice_nums=max_slice_nums) - best_grids.append(grid) - - patches = [ - torch.as_tensor(x).to(dtype=torch.half, device=self.model.device) - for x in patches + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + # collect preprocess results into a list + inputs = [] + inputs = [ + x['preprocess'] for x in messages if 'preprocess' in x.keys() ] - patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches] - tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes] - tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) - max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches, + # flatten the list + inputs = list(itertools.chain(*inputs)) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + tgt_sizes = [ + x['tgt_sizes'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + num_patches = [ + x['num_patches'] for x in inputs[idx:idx + max_batch_size] + ] + # flatten the list + tgt_sizes = list(itertools.chain(*tgt_sizes)) + pixel_values = list(itertools.chain(*pixel_values)) + pixel_values = [ + x.to(dtype=torch.half, device=self.model.device) + for x in pixel_values + ] + pixel_values = [ + x.flatten(end_dim=1).permute(1, 0) for x in pixel_values + ] + pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute(0, 2, - 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=self.model.device) - for i in range(B): - patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True - vision_embedding = self.model.vpm( - all_pixel_values.type(torch.half), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state - vision_embedding = self.model.resampler(vision_embedding, tgt_sizes) - vision_embedding = torch.split(vision_embedding, num_patches, 0) - outputs = [] - for embeddings, grid in zip(vision_embedding, best_grids): - embeddings = embeddings.cpu() # n x d x h - outputs.append( - dict(embeddings=embeddings, - grid=grid, - use_image_id=use_image_id)) + B, L, _ = pixel_values.shape + pixel_values = pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) + tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) + max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=self.model.device) + logger.info(f'vision forward shape: {pixel_values.shape}') + if self.version == '2.5': + for j in range(B): + patch_attn_mask[j, :tgt_sizes[j][0] * + tgt_sizes[j][1]] = True + embeddings = self.model.vpm( + pixel_values.type(torch.half), + patch_attention_mask=patch_attn_mask).last_hidden_state + else: + for j in range(B): + patch_attn_mask[j, 0, :tgt_sizes[j][0] * + tgt_sizes[j][1]] = True + embeddings = self.model.vpm( + pixel_values.type(torch.half), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state - return outputs + embeddings = self.model.resampler(embeddings, tgt_sizes) + embeddings = torch.split(embeddings, num_patches, 0) + for embedding in embeddings: + embedding = embedding.split(1, dim=0) + outputs.extend([x.squeeze() for x in embedding]) + messages.append(dict(role='forward', content=outputs)) + return messages - @torch.no_grad() - def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - images = [x.convert('RGB') for x in images] - return self._forward_func(images, params) + def proc_messages(self, messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + idx = 0 + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + if 'preprocess' not in message.keys(): + continue + prompts = [] + for x in message['preprocess']: + prompt = f'{IMAGE_TOKEN}' + if x.get('use_image_id', False): + prompt = f'{idx}' + prompt + idx += 1 + grid = x['best_grid'] + if grid is not None: + if self.version == '2.5': + slice = '\n'.join( + [f'{IMAGE_TOKEN}' * grid[0]] * + grid[1]) + prompt = f'{prompt}{slice}\n' + elif self.version == '2.6': + slice = '\n'.join( + [f'{IMAGE_TOKEN}' * grid[0]] * + grid[1]) + prompt = prompt + slice + prompt += '\n' + else: + prompt = (prompt + + '\n' if self.version == '2.6' else prompt) + prompts.append(prompt) + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = ''.join(prompts) + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/mllama.py b/lmdeploy/vl/model/mllama.py index db0a0e9cbf..0cae71cd6c 100644 --- a/lmdeploy/vl/model/mllama.py +++ b/lmdeploy/vl/model/mllama.py @@ -2,192 +2,10 @@ from typing import Dict, List -import torch -import torch.nn.functional as F -from PIL.Image import Image -from transformers.modeling_outputs import BaseModelOutput -from transformers.models.mllama.modeling_mllama import MllamaPreTrainedModel - from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging - - -class MllamaVisionModelPatch(MllamaPreTrainedModel): - - def apply_class_embedding(self, - hidden_state: torch.Tensor) -> torch.Tensor: - batch_size, _, hidden_size = hidden_state.shape - class_embedding = self.class_embedding.expand(batch_size, 1, - hidden_size) - class_embedding = class_embedding.to(hidden_state.device) - hidden_state = torch.cat([class_embedding, hidden_state], dim=1) - return hidden_state - - def forward( - self, - pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor, - output_attentions: bool = None, - output_hidden_states: bool = None, - return_dict: bool = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa - - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape # noqa - - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, - height, width) - aspect_ratio_ids = aspect_ratio_ids.reshape( - batch_size * num_concurrent_media, -1) - - # Patch embedding - patch_embeds = self.patch_embedding( - pixel_values.to(self.dtype).to(self.device)) - hidden_state = patch_embeds.flatten(2).transpose(1, 2) - - # Tile embeddings - _, num_patches, dim = hidden_state.shape - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, -1, dim) - hidden_state = self.pre_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - - # Add cls token - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim) - hidden_state = self.apply_class_embedding(hidden_state) - num_patches += 1 - - # Position embeddings - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, num_patches, dim) - hidden_state = self.gated_positional_embedding(hidden_state, - aspect_ratio_ids) - - hidden_state = self.layernorm_pre(hidden_state) - - # Compute the number of tokens to pad - num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 - # Compute padding tuple for pad function - padding = ( - 0, 0, 0, num_padding_patches - ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) - # Pad the tensor - hidden_state = F.pad(hidden_state, padding, mode='constant', value=0) - slice_index = -num_padding_patches if num_padding_patches > 0 else None - - # Prepare attention mask - attention_mask = aspect_ratio_mask.reshape( - batch_size * num_concurrent_media, -1) - from transformers.models.mllama.modeling_mllama import \ - _prepare_aspect_ratio_attention_mask - attention_mask = _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask=attention_mask, - num_patches=self.num_patches, - target_length=hidden_state.shape[2], - dtype=self.dtype, - ) - - # Apply encoder - hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, - dim) - output = self.transformer( - hidden_state, - attention_mask=attention_mask, - output_hidden_states=True, - output_attentions=output_attentions, - ) - hidden_state = output[0] - - hidden_state = self.layernorm_post(hidden_state) - - # Apply global encoder - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = self.post_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles * (num_patches + num_padding_patches), dim) - global_output = self.global_transformer( - hidden_state, - attention_mask=attention_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - hidden_state = global_output[0] - - # Remove padding form hidden state - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = hidden_state[:, :, :slice_index] - hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, - num_tiles, num_patches, dim) - - # Collect intermediate layer outputs from encoder output - all_intermediate_hidden_states = output[1] - # rewrite to sync device during accelerate pipeline parallel - device = hidden_state.device - all_intermediate_hidden_states = [ - s.to(device) for s in all_intermediate_hidden_states - ] - intermediate_hidden_states = torch.stack( - all_intermediate_hidden_states, dim=-1) - intermediate_hidden_states = intermediate_hidden_states[ - ..., self.intermediate_layers_indices] - - # Remove padding from intermediate hidden states - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size * num_concurrent_media, num_tiles, - num_patches + num_padding_patches, -1) - intermediate_hidden_states = intermediate_hidden_states[:, :, : - slice_index] - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, -1) - - # Concatenate final hidden state and intermediate hidden states - hidden_state = torch.cat([hidden_state, intermediate_hidden_states], - dim=-1) - - if output_hidden_states: - hidden_states = tuple(all_intermediate_hidden_states) + tuple( - global_output[1]) - else: - hidden_states = None - - if output_attentions: - # global transformer in contrast to `self.transformer` doesn't - # always return hidden states so we might go index out-of-range - global_attn = tuple( - global_output[2]) if output_hidden_states else tuple( - global_output[1]) - attentions = tuple(output[2]) + global_attn - else: - attentions = None - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states, attentions] - if v is not None) - - return BaseModelOutput( - last_hidden_state=hidden_state, - hidden_states=hidden_states, - attentions=attentions, - ) def check_transformers(): - """check qwen_vl_utils.""" try: from transformers import MllamaForConditionalGeneration # noqa: F401 except ImportError: @@ -202,85 +20,60 @@ class MllamaVLModel(VisonModel): _arch = 'MllamaForConditionalGeneration' - def build_model(self): - check_transformers() - - from transformers.models.mllama.modeling_mllama import \ - MllamaVisionModel - MllamaVisionModel.forward = MllamaVisionModelPatch.forward - MllamaVisionModel.apply_class_embedding = MllamaVisionModelPatch.apply_class_embedding # noqa - from accelerate import init_empty_weights - with init_empty_weights(): - config = self.hf_config - config.quantization_config = {} # disable vision part quantization - # disable accelerate check_tied_parameters_in_config - config.tie_word_embeddings = False - from transformers import MllamaForConditionalGeneration - model = MllamaForConditionalGeneration._from_config(config) - if not self.with_llm: - del model.language_model - else: - self.vl_model = model - - from accelerate import load_checkpoint_and_dispatch - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map='auto' if not self.with_llm else {'': 'cpu'}, - max_memory=self.max_memory, - no_split_module_classes=[ - 'MllamaPrecomputedPositionEmbedding', - 'MllamaPrecomputedAspectRatioEmbedding', - 'MllamaVisionEncoderLayer' - ], - dtype=config.torch_dtype) - - self.model = model.eval() - - # processor + def build_preprocessor(self): from transformers import AutoProcessor self.processor = AutoProcessor.from_pretrained(self.model_path) self.image_token_id = 128256 - @torch.no_grad() - def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - # only support image input - if params is not None: - assert len(images) == len( - params), 'different length of images and params' + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to the spec of `super().preprocess`""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + results = self.processor.image_processor(images=image, + return_tensors='pt') + results.update(image_size=image.size, + image_tokens=1, + image_token_id=self.image_token_id) + outputs.append(results) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + def build_model(self): + check_transformers() + if self.with_llm: + from transformers import MllamaForConditionalGeneration + model = MllamaForConditionalGeneration.from_pretrained( + self.model_path, device_map='cpu') + self.vl_model = model else: - params = [{}] * len(images) - # resize images with abnormal shape - # TODO try catch image feature extraction in pipeline and - # throw error back to users - for i, image in enumerate(images): - size = image.size - if any([s < 3 for s in size]): - images[i] = image.resize([s * 3 for s in size]) - image_inputs = self.processor.image_processor(images=images, - return_tensors='pt') - pixel_values = image_inputs['pixel_values'].to( - self.model.vision_model.device) - pixel_values = pixel_values.type(self.model.vision_model.dtype) - aspect_ratio_ids = image_inputs['aspect_ratio_ids'].to( - self.model.vision_model.device) - aspect_ratio_mask = image_inputs['aspect_ratio_mask'].to( - self.model.vision_model.device) - vision_outputs = self.model.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=False, - output_attentions=False, - return_dict=True) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.model.multi_modal_projector( - cross_attention_states) - _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape) - cross_attention_states = cross_attention_states.view( - bsz, -1, image_token_dim).split([1] * len(images)) - return cross_attention_states + raise NotImplementedError('turbomind has not supported mllama yet') + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '<|image|>' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = (IMAGE_TOKEN) * n_images + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py index 9abae7a309..eccf62ebb6 100644 --- a/lmdeploy/vl/model/molmo.py +++ b/lmdeploy/vl/model/molmo.py @@ -3,11 +3,9 @@ from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoModelForCausalLM, AutoProcessor from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging @@ -20,20 +18,26 @@ class MolmoVisionModel(VisonModel): _arch = 'MolmoForCausalLM' + def build_preprocessor(self): + self.processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True, + torch_dtype=torch.half, + device_map='auto') + def build_model(self): - """Load model.""" + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights, load_checkpoint_and_dispatch with init_empty_weights(): - config = self.hf_config - model = AutoModelForCausalLM.from_config(config, + model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True) + + self.vl_model = model if not self.with_llm: # Remove nn modules other than embedding from the LLM model for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']: del model.model.transformer[key] - self.token_embedding = model.model.transformer.wte - else: - self.vl_model = model + self.token_embedding = model.model.transformer.wte with disable_logging(): load_checkpoint_and_dispatch( @@ -43,118 +47,161 @@ def build_model(self): max_memory=self.max_memory, no_split_module_classes=[ 'ResidualAttentionBlock', 'Embedding' - ]) + ], + dtype=torch.half) # We need eval mode to freeze the weights in model, thus, # avoid randomness in inference. self.model = model.eval() - self.config = config - self.processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True, - torch_dtype='auto', - device_map='auto') + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to the `super.preprocess() for spec.""" + for i, message in enumerate(messages): + if not isinstance(message['content'], List): + continue + images = [ + x['image'] for x in message['content'] if x['type'] == 'image' + ] + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = f' User: {content[0]}' + tokens = self.processor.tokenizer.encode(prompt, + add_special_tokens=False) + # preprocess images. The output is a dict, which is + # { + # 'input_ids': torch.Tensor, + # 'images': torch.Tensor, # (n_patch, d_model) + # 'image_input_idx': torch.Tensor, # (n_patch, d_model) + # 'image_masks': torch.Tensor, # (n_patch, d_model) + # } + result = self.processor.process(images=images, tokens=tokens) + # remove the bos from input_ids which is prepended by molmo's + # processor + input_ids = result['input_ids'][1:] + result.update(input_ids=input_ids) + messages[i].update(preprocess=result) + return messages @torch.no_grad() def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[Dict]: - """forward the model with given input. + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. Args: - images (List): [None] it is not used - params (List): the inputs after precessing GPT4V messages in - `MolmoChatTemplateWrapper`. Its format is like the following: - [[ - {'role': 'user', 'content': 'user prompt'}, - {'role': 'asssistant', 'content': 'assistant prompt'}, - {'role': 'user', 'content': 'user prompt', 'images': [PIL image list]}, - ... - ]] - """ # noqa - - messages = params[0] - assert isinstance(messages, List) - # append an assistant message to `messages` - messages.append(dict(role='assistant', content='')) + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + for i, message in enumerate(messages): + if 'preprocess' not in message.keys(): + continue + inputs = message['preprocess'] + # get input_ids of embedding + inputs = { + k: v.to(self.model.device).unsqueeze(0) + for k, v in inputs.items() + } + input_ids = inputs['input_ids'] + # (batch_size, num_image, num_patch, d_model) + images = inputs['images'] + # (batch_size, num_image, num_patch) + image_input_idx = inputs['image_input_idx'] + image_masks = inputs['image_masks'] + batch_size, seq_len = input_ids.size() + assert batch_size == 1 + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + embeddings = self.model.model.transformer.wte(input_ids) + images = images.to(self.model.dtype) + image_masks = image_masks.to(self.model.dtype) + logger.info(f'vision forward shape: {images.shape}') + image_features, _ = self.model.model.vision_backbone( + images, image_masks) + num_image, num_patch = image_features.shape[1:3] + assert image_input_idx.shape == (batch_size, num_image, num_patch) + + # insert the image feature into the embedding. + image_features = image_features.view(batch_size, + num_image * num_patch, -1) + image_input_idx = image_input_idx.view(batch_size, + num_image * num_patch) + valid = image_input_idx >= 0 + batch_idx = torch.arange(batch_size, device=embeddings.device) + batch_idx = torch.tile(batch_idx[:, None], + [1, image_features.shape[1]]) + image_features = image_features.to(embeddings.device) + # Since we remove bos_id from input_ids during `preprocess`, + # the index `image_input_idx[valid]` should be shift to left + # by subtracting 1 + index = image_input_idx[valid] - 1 + embeddings[batch_idx[valid], index] += image_features[valid] + assert embeddings.shape[:2] == (batch_size, seq_len) + messages[i].update( + dict(forward=dict(input_ids=input_ids.flatten(), + embeddings=embeddings))) + return messages + + @staticmethod + def proc_messages(messages): + prompt = [] + IMAGE_TOKEN = '' + for message in messages: + role, content = message['role'], message['content'] + if isinstance(content, List): + n_images = len([1 for x in content if x['type'] == 'image']) + content = [x['text'] for x in content if x['type'] == 'text'] + prompt.append(' User: ' + (IMAGE_TOKEN + '\n') * n_images + + content[0]) + else: + if role == 'user': + prompt.append(f' User: {content}') + elif role == 'assistant': + prompt.append(f' Assistant:{content}') + else: + assert 0, f'molmo does not support role {role}, message is {message}' # noqa + prompt.append(' Assistant:') + return ''.join(prompt) + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + assert 0, 'molmo is not supported by pytorch engine' + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): # results is a list of tuple(input_ids, embeddings) results = [] - # the concat prompt. It is not used during inference but to adhere the - # interface definition of `_get_prompt_input` in `class VLAsyncEngine` - prompts = '' # Prepend BOS # qwen2 and olmo do not have a BOS, and instead use EOS as a generic # separator token. bos = (self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id) results.append(([bos], None)) + for i, message in enumerate(messages): - if 'images' in message.keys(): - prompts += ' User: ' + (IMAGE_TOKEN + '\n') * len( - message['images']) + message['content'] - prompt = f' User: {message["content"]}' - tokens = self.processor.tokenizer.encode( - prompt, add_special_tokens=False) - # preprocess images. The output is a dict - inputs = self.processor.process(images=message['images'], - tokens=tokens) - inputs = { - k: v.to(self.model.device).unsqueeze(0) - for k, v in inputs.items() - } - input_ids = inputs['input_ids'] - # remove the bos from input_ids which is prepended by molmo's - # processor - input_ids = input_ids[:, 1:] - images = inputs[ - 'images'] # (batch_size, num_image, num_patch, d_model) - image_input_idx = inputs[ - 'image_input_idx'] # (batch_size, num_image, num_patch) - image_masks = inputs['image_masks'] - batch_size, seq_len = input_ids.size() - assert batch_size == 1 - - # Get embeddings of input. - if input_ids is not None: - input_ids = input_ids * (input_ids != -1).to( - input_ids.dtype) - embeddings = self.model.model.transformer.wte(input_ids) - image_features, _ = self.model.model.vision_backbone( - images, image_masks) - num_image, num_patch = image_features.shape[1:3] - assert image_input_idx.shape == (batch_size, num_image, - num_patch) - - # insert the image feature into the embedding. - image_features = image_features.view(batch_size, - num_image * num_patch, -1) - image_input_idx = image_input_idx.view(batch_size, - num_image * num_patch) - - valid = image_input_idx >= 0 - batch_idx = torch.arange(batch_size, device=embeddings.device) - batch_idx = torch.tile(batch_idx[:, None], - [1, image_features.shape[1]]) - image_features = image_features.to(embeddings.device) - embeddings[batch_idx[valid], - image_input_idx[valid]] += image_features[valid] - assert embeddings.shape[:2] == (batch_size, seq_len) - results.append((input_ids.flatten().tolist(), embeddings)) + prompt = '' + role, content = message['role'], message['content'] + if isinstance(content, List): + forward_result = message.pop('forward') + input_ids = forward_result['input_ids'] + embeddings = forward_result['embeddings'] + results.append((input_ids.tolist(), embeddings)) else: - role = message['role'] - content = message['content'] - assert isinstance(content, str) - prompt = '' if role == 'user': prompt = f' User: {content}' elif role == 'assistant': prompt = f' Assistant:{content}' else: assert 0, f'molmo does not support role {role}, message is {message}' # noqa + if i == len(messages) - 1: + # the last message + assert role == 'user', f'the role of last message is expected to be user, but got {role}' # noqa + prompt += ' Assistant:' + if prompt: input_ids = self.processor.tokenizer.encode( prompt, add_special_tokens=False) results.append((input_ids, None)) - prompts += prompt # concat input_ids from results, calculate the range in the input_ids # where embeddings will be copied to @@ -169,9 +216,9 @@ def forward(self, input_embedding_ranges.append((start, end)) input_ids += _input_ids start += len(_input_ids) - return [ - dict(prompt=prompts, - input_ids=input_ids, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges) - ] + + prompt = self.proc_messages(messages) + return dict(prompt=prompt, + input_ids=input_ids, + input_embeddings=input_embeddings, + input_embedding_ranges=input_embedding_ranges) diff --git a/lmdeploy/vl/model/phi3_vision.py b/lmdeploy/vl/model/phi3_vision.py index 032b8404da..ff00b5d1d9 100644 --- a/lmdeploy/vl/model/phi3_vision.py +++ b/lmdeploy/vl/model/phi3_vision.py @@ -1,198 +1,48 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings -from typing import List +from typing import Dict, List -import torch -from PIL.Image import Image from transformers import AutoProcessor -from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging - - -# from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py # noqa E501 -def _process_image_embedding(self, pixel_values: torch.Tensor, - image_sizes: torch.Tensor): - """process image embedding.""" - img_embeds = pixel_values - img_sizes = image_sizes - target_device = pixel_values.device - target_dtype = pixel_values.dtype - if self.use_hd_transform and img_sizes is not None and len(img_sizes): - assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501 - # img_embeds: (num_images, max_num_crops, 3, H, W) - # img_sizes: (num_images, 2).view(1, -1) - - bs = img_embeds.shape[0] - # Nx(HW)xC - img_features = self.get_img_features(img_embeds.flatten(0, 1)) - base_feat_height = base_feat_width = int(img_features.shape[1]**0.5) - - assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501 - - # bs x max_num_crops x (24x24) x C - img_features = img_features.view(bs, -1, - base_feat_height * base_feat_width, - self.image_dim_out) - C = self.image_dim_out - H = base_feat_height - - output_imgs = [] - output_len = [] - # training is tensor, inference is list - if isinstance(img_sizes, torch.Tensor): - img_sizes = img_sizes.view(-1, 2) - for _bs in range(bs): - h, w = img_sizes[_bs] - h = h // 336 - w = w // 336 - B_ = h * w - - # 1 x (24x24) x 1024 - global_img_feature = img_features[_bs, :1] - - # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1, H, H, C).reshape( - 1, H // 2, 2, H // 2, 2, - C).contiguous().permute(0, 1, 3, 2, 4, - 5).reshape(1, H // 2, H // 2, - 4 * C).contiguous() - temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) - - # 1 x 156 x 4096 - glb_img = torch.cat([glb_img, temp_glb_GN], - dim=2).reshape(1, -1, 4 * C) - - # (max_num_crops-1) x (12x12) x C - sub_img = img_features[_bs, 1:] - # 16x574x1024 - # get rid of padding sub_img - sub_img = sub_img[:B_] - - # (num_crops, 12, 2, 12, 2, 1024)->(num_crops, 12, 12, 2, 2, 1024) - # -> (num_crops, 12*12, 4*1024) - sub_img = sub_img.reshape(B_, H, H, C).reshape( - B_, H // 2, 2, H // 2, 2, - C).contiguous().permute(0, 1, 3, 2, 4, - 5).reshape(B_, -1, 4 * C).contiguous() - sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( - 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) - temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) - sub_img = torch.cat([sub_img, temp_sub_GN], - dim=2).reshape(1, -1, 4 * C) - # (1, num_img_tokens, 1024*4) - - # glb + sub - if self.hd_transform_order == 'glb_sub': - output_imgs.append( - torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) - elif self.hd_transform_order == 'sub_glb': - output_imgs.append( - torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) - else: - raise NotImplementedError( - f'hd_transform_order = {self.hd_transform_order}' - ) # noqa E501 - - temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) - assert temp_len == output_imgs[-1].shape[ - 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501 - output_len.append(temp_len) - - img_set_tensor = [] - for _output_img in output_imgs: - img_feature_proj = self.img_projection( - _output_img.to(target_device).to(target_dtype)) - img_set_tensor.append(img_feature_proj) - elif img_embeds.ndim == 4: - tt = (self.get_img_features(img_embeds).to(target_device).to( - target_dtype).reshape(-1, self.image_dim_out)) - img_set_tensor = self.img_projection(tt) # adapted visual features. - elif img_embeds.ndim == 3: - tt = (img_embeds.to(target_device).to(target_dtype).view( - -1, self.image_dim_out)) - img_set_tensor = self.img_projection(tt) # adapted visual features. - else: - raise NotImplementedError - return img_set_tensor +from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel @VISION_MODELS.register_module() -class Phi3VisionModel(VisonModel): - """Llava hf vision model.""" +class Phi3VisionModel(LlavaHfVisionModel): + """Phi3-vision model.""" _arch = 'Phi3VForCausalLM' - def build_model(self): - from accelerate import init_empty_weights, load_checkpoint_and_dispatch - from accelerate.utils import get_balanced_memory, infer_auto_device_map - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_config(self.hf_config, - trust_remote_code=True) - if not self.with_llm: - del model.lm_head - del model.model.layers - del model.model.norm - del model.model.embed_tokens - del model.model.vision_embed_tokens.wte - else: - self.vl_model = model - - no_split_module_classes = ['CLIPEncoderLayer'] - max_memory = get_balanced_memory( - model, - max_memory=self.max_memory, - dtype=torch.half, - no_split_module_classes=no_split_module_classes) - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_module_classes, - max_memory=max_memory, - dtype=torch.half) - same_device_keys = [('model.vision_embed_tokens.img_projection', - 'model.vision_embed_tokens.sub_GN', - 'model.vision_embed_tokens.glb_GN')] - for keys in same_device_keys: - keys = [k for k in keys if k in device_map] - if len(keys) <= 1: - continue - for k in keys[1:]: - device_map[k] = device_map[keys[0]] - - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map=device_map if not self.with_llm else {'': 'cpu'}, - no_split_module_classes=no_split_module_classes, - dtype=torch.half) - - model.eval() - self.model = model - # processor + def build_preprocessor(self): processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) if hasattr(processor, 'tokenizer'): del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + processor.tokenizer = None self.processor = processor - @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - process_outputs = self.processor.image_processor( - images, return_tensors='pt').to(device=self.model.device, - dtype=self.model.dtype) - pixel_values = process_outputs['pixel_values'] - image_sizes = process_outputs['image_sizes'] - image_features = _process_image_embedding( - self.model.model.vision_embed_tokens, - pixel_values=pixel_values, - image_sizes=image_sizes) - outputs = [x.squeeze() for x in image_features] - return outputs + def build_model(self): + if self.with_llm: + from transformers import AutoModelForCausalLM + self.vl_model = AutoModelForCausalLM.from_pretrained( + self.model_path, device_map='cpu', trust_remote_code=True) + else: + raise NotImplementedError('turbomind has not supported phi3v yet') + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + result = self.processor.image_processor(image, return_tensors='pt') + h = result['image_sizes'][0][0].item() // 336 + w = result['image_sizes'][0][1].item() // 336 + image_tokens = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + result.update( + dict(image_size=image.size, + image_tokens=image_tokens, + image_token_id=0)) + outputs.append(result) + messages.append(dict(role='preprocess', content=outputs)) + return messages diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py index 3968f27d97..49631ccf35 100644 --- a/lmdeploy/vl/model/qwen.py +++ b/lmdeploy/vl/model/qwen.py @@ -1,14 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List +from typing import Dict, List import torch -from PIL.Image import Image from transformers import AutoModelForCausalLM +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +logger = get_logger('lmdeploy') + @VISION_MODELS.register_module() class QwenVisionModel(VisonModel): @@ -16,19 +18,33 @@ class QwenVisionModel(VisonModel): _arch = 'QWenLMHeadModel' + def build_preprocessor(self): + from torchvision import transforms + from torchvision.transforms import InterpolationMode + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + image_size = self.hf_config.visual['image_size'] + self.image_transform = transforms.Compose([ + transforms.Resize((image_size, image_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights with init_empty_weights(): config = self.hf_config config.quantization_config = {} # disable vision part quantization model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + self.vl_model = model if not self.with_llm: del model.lm_head for key in ['wte', 'h', 'ln_f']: setattr(model.transformer, key, None) - else: - self.vl_model = model from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory( @@ -60,13 +76,86 @@ def build_model(self): self.model = model.transformer.visual.eval() + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refers to `super.preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = self.image_transform(image) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=256, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages + @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.model.image_transform(x) for x in outputs] - outputs = torch.stack(outputs, dim=0) - outputs = self.model(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.stack(pixel_values, dim=0) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model(pixel_values) + feats = torch.split(feats, 1, dim=0) + outputs.extend([x.squeeze() for x in feats]) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + x['text'] for x in message['content'] if x['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt: + pass + else: + prompt = ''.join([ + f'Picture {str(i)}:{IMAGE_TOKEN}\n' + for i in range(n_images) + ]) + prompt + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py index 3eb3c1541c..ed9da332e0 100644 --- a/lmdeploy/vl/model/qwen2.py +++ b/lmdeploy/vl/model/qwen2.py @@ -1,12 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. - from typing import Dict, List import torch -from PIL.Image import Image from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import disable_logging def check_qwen_vl_deps_install(): @@ -15,7 +12,7 @@ def check_qwen_vl_deps_install(): import qwen_vl_utils # noqa: F401 except ImportError: raise ImportError( - 'please install qwen_vl_utils by pip install qwen_vl_utils' # noqa: E501 + 'please install qwen_vl_utils by `pip install qwen_vl_utils`' # noqa: E501 ) try: from transformers import Qwen2VLForConditionalGeneration # noqa: F401 @@ -31,85 +28,105 @@ class Qwen2VLModel(VisonModel): _arch = 'Qwen2VLForConditionalGeneration' + def build_preprocessor(self): + check_qwen_vl_deps_install() + from transformers import AutoProcessor + self.processor = AutoProcessor.from_pretrained(self.model_path) + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess()` for spec.""" + from qwen_vl_utils import process_vision_info + + images = self.collect_images(messages) + optional_keys = { + 'resized_height', 'resized_width', 'min_pixels', 'max_pixels' + } + outputs = [] + for image, params in images: + image = image.convert('RGB') + + item = dict(type='image', image=image) + item.update({ + key: params[key] + for key in params.keys() if key in optional_keys + }) + image_inputs, _ = process_vision_info([dict(content=[item])]) + result = self.processor.image_processor(images=image_inputs, + videos=None, + return_tensors='pt') + merge_length = self.processor.image_processor.merge_size**2 + image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length + result.update( + dict(image_size=image.size, + image_tokens=image_tokens, + image_token_id=0)) + outputs.append(result) + messages.append(dict(role='preprocess', content=outputs)) + return messages + def build_model(self): check_qwen_vl_deps_install() from transformers import Qwen2VLForConditionalGeneration if self.with_llm: - model = Qwen2VLForConditionalGeneration.from_pretrained( - self.hf_config._name_or_path, trust_remote_code=True) - model.half() - self.vl_model = model + self.vl_model = Qwen2VLForConditionalGeneration.from_pretrained( + self.model_path, device_map='cpu') else: - from accelerate import init_empty_weights - with init_empty_weights(): - config = self.hf_config - config.quantization_config = { - } # disable vision part quantization - # disable accelerate check_tied_parameters_in_config - # for Qwen2-VL-2B-Instruct - config.tie_word_embeddings = False - - model = Qwen2VLForConditionalGeneration._from_config(config) - del model.model - del model.lm_head - model.half() - from accelerate import load_checkpoint_and_dispatch - with disable_logging(): - load_checkpoint_and_dispatch( - model=model, - checkpoint=self.model_path, - device_map='auto' if not self.with_llm else {'': 'cpu'}, - max_memory=self.max_memory, - no_split_module_classes=['Qwen2VLVisionBlock'], - dtype=torch.half) - - self.model = model.eval() - - # processor - from transformers import AutoProcessor - self.processor = AutoProcessor.from_pretrained(self.model_path) + raise NotImplementedError( + 'turbomind has not supported qwen2-vl yet') @torch.no_grad() def forward(self, - images: List[Image], - params: List[Dict] = None) -> List[torch.Tensor]: - """forward.""" - # only support image input - if params is not None: - assert len(images) == len( - params), 'different length of images and params' - else: - params = [{}] * len(images) + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. - from qwen_vl_utils import process_vision_info - images = [x.convert('RGB') for x in images] - content = [] - optional_keys = [ - 'resized_height', 'resized_width', 'min_pixels', 'max_pixels' - ] - for image, param in zip(images, params): - item = dict(type='image', image=image) - item.update({k: param[k] for k in optional_keys if k in param}) - content.append(item) - messages = [dict(content=content)] - image_inputs, _ = process_vision_info(messages) - image_inputs = self.processor.image_processor(images=image_inputs, - videos=None, - return_tensors='pt') - pixel_values = image_inputs['pixel_values'].to( - self.model.visual.get_device()) - image_grid_thw = image_inputs['image_grid_thw'].to( - self.model.visual.get_device()) - pixel_values = pixel_values.type(self.model.visual.get_dtype()) - image_embeds = self.model.visual(pixel_values, - grid_thw=image_grid_thw).cpu() - merge_length = self.processor.image_processor.merge_size**2 - split_size = image_inputs['image_grid_thw'].prod(dim=1) // merge_length - image_embeds = image_embeds.split(split_size.tolist()) + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + assert 0, 'TODO: support turbomind engine' - outputs = [] - for i, embeddings in enumerate(image_embeds): - outputs.append( - dict(embeddings=embeddings, - grid_thw=image_inputs['image_grid_thw'][i].tolist())) - return outputs + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = content[0] + if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt: + prompt = prompt.replace( + IMAGE_TOKEN, + f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>') + else: + # Qwen2-VL-2B-Instruct will concat image and user prompt + # according to their order in the content list + # we insert image token before user prompt by default. The + # user can use custom image token position if they want the + # same decorated prompt as Qwen2-VL + prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \ + n_images + prompt + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + """return to the information needed by pytorch engine.""" + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py index 96bc900c02..3c72d0c29f 100644 --- a/lmdeploy/vl/model/xcomposer2.py +++ b/lmdeploy/vl/model/xcomposer2.py @@ -5,7 +5,7 @@ import sys import warnings from contextlib import contextmanager -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import torch from PIL.Image import Image @@ -19,6 +19,17 @@ logger = get_logger('lmdeploy') +def check_xcomposer_install(): + try: + # WARNING! we have to do this otherwise the model_type is wrong for + # xcomposer2d5 + import decord # noqa: F401 + except ImportError: + raise ImportError( + "No module named 'decord'. Please install decord by `pip install decord`" # noqa + ) + + class ModelType(enum.Enum): """Request type.""" XCOMPOSER2 = enum.auto() @@ -83,6 +94,17 @@ def init_empty_vit(model_path): class Xcomposer2VisionModel(VisonModel): """InternLM-Xcomposer2 vision model.""" + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) + check_xcomposer_install() + self.model_type, self.module = get_xcomposer_type(self.model_path) + logger.info(f'matching type of {self.model_type}') + @classmethod def match(cls, config: AutoConfig): """check whether the config match the model.""" @@ -94,7 +116,37 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + + import torchvision.transforms as transforms + from torchvision.transforms.functional import InterpolationMode + + if self.model_type in [ + ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD + ]: + self.HD_transform = self.module + self.vis_processor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + self.preprocess_func = (self._preprocess_2d5 if self.model_type + == ModelType.XCOMPOSER2D5 else + self._preprocess_4khd_7b) + else: + self.vis_processor = transforms.Compose([ + transforms.Resize( + (self.hf_config.img_size, self.hf_config.img_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + self.preprocess_func = self._preprocess_7b + def build_model(self): + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" from accelerate import init_empty_weights with init_empty_weights(), warnings.catch_warnings(), \ init_empty_vit(self.model_path): @@ -106,23 +158,10 @@ def build_model(self): model.vit.resize_pos() model.vit.vision_tower.vision_model.post_layernorm.to_empty( device='cpu').half() + self.vl_model = model if not self.with_llm: del model.model del model.output - else: - self.vl_model = model - - # additional components. - model_type, module = get_xcomposer_type(self.model_path) - logger.info(f'matching type of {model_type}') - if model_type == ModelType.XCOMPOSER2D5: - self.HD_transform = module - self._forward_func = self._forward_2d5 - elif model_type == ModelType.XCOMPOSER2_4KHD: - self.HD_transform = module - self._forward_func = self._forward_4khd_7b - else: - self._forward_func = self._forward_7b from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory( @@ -156,51 +195,117 @@ def build_model(self): self.model = model.eval() - def _forward_2d5(self, images: List[Image]) -> List[torch.Tensor]: - """internlm-xcomposer2d5-7b vit forward.""" - outputs = [x.convert('RGB') for x in images] - hd_num = 6 if len(images) > 1 else 24 - outputs = [self.HD_transform(x, hd_num=hd_num) for x in outputs] - outputs = [ - self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half) - for x in outputs - ] - embeds, split = self.model.vit(outputs, self.model.plora_glb_GN, - self.model.plora_sub_GN) - embeds = self.model.vision_proj(embeds) - embeds = torch.split(embeds, split, dim=1) - embeds = [x.squeeze() for x in embeds] - return embeds - - def _forward_7b(self, images: List[Image]) -> List[torch.Tensor]: - """internlm-xcomposer2-7b vit forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [ - self.model.vis_processor(x).unsqueeze(0).half() for x in outputs - ] - outputs = torch.cat(outputs, dim=0) - outputs = self.model.vit(outputs) - outputs = self.model.vision_proj(outputs) - outputs = torch.split(outputs, 1, dim=0) - outputs = [x.squeeze() for x in outputs] - return outputs - - def _forward_4khd_7b(self, images: List[Image]) -> List[torch.Tensor]: - """internlm-xcomposer2-4khd-7b vit forward.""" - outputs = [x.convert('RGB') for x in images] - outputs = [self.HD_transform(x, hd_num=25) for x in outputs] - outputs = [ - self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half) - for x in outputs - ] - embeds, split = self.model.vit(outputs, self.model.plora_glb_GN, - self.model.plora_sub_GN) - embeds = self.model.vision_proj(embeds) - embeds = torch.split(embeds, split, dim=1) - embeds = [x.squeeze() for x in embeds] - return embeds + def _preprocess_2d5(self, image: Image, params: Dict) -> Dict: + """image preprocessing for internlm-xcomposer2d5-7b.""" + hd_num = params.get('hd_num', 24) + image = self.HD_transform(image, hd_num=hd_num) + pixel_values = self.vis_processor(image).unsqueeze(0).half() + w, h = image.size + n_token_per_image = int((h * w + 1) * 400 + 1 + (h + 1) * 20) + return pixel_values, n_token_per_image + + def _preprocess_7b(self, image: Image, params: Dict) -> Dict: + """image preprocessing for internlm-xcomposer2-7b.""" + pixel_values = self.vis_processor(image).unsqueeze(0).half() + return pixel_values, 256 + + def _preprocess_4khd_7b(self, image: Image, params: Dict) -> Dict: + """image preprocessing for internlm-xcomposer2-4khd-7b.""" + image = self.HD_transform(image, hd_num=25) + pixel_values = self.vis_processor(image).unsqueeze(0).half() + w, h = image.size + n_token_per_image = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + return pixel_values, n_token_per_image + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values, n_token = self.preprocess_func(image, params) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=n_token, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages @torch.no_grad() - def forward(self, images: List[Image]) -> List[torch.Tensor]: - """forward.""" - return self._forward_func(images) + def forward(self, + messages: List[Dict], + max_batch_size: int = 1) -> List[Dict]: + """extract image feature. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'] + inputs = inputs[0] + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + if self.model_type in [ + ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD + ]: + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + embeds, split = self.model.vit(pixel_values, + self.model.plora_glb_GN, + self.model.plora_sub_GN) + embeds = self.model.vision_proj(embeds) + embeds = torch.split(embeds, split, dim=1) + embeds = [x.squeeze() for x in embeds] + else: + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = torch.cat(pixel_values, dim=0) + logger.info(f'vision forward shape: {pixel_values.shape}') + embeds = self.model.vit(pixel_values) + embeds = self.model.vision_proj(embeds) + embeds = torch.split(embeds, 1, dim=0) + embeds = [x.squeeze() for x in embeds] + outputs.extend(embeds) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len( + [1 for x in message['content'] if x['type'] == 'image']) + content = [ + item['text'] for item in message['content'] + if item['type'] == 'text' + ] + prompt = ' '.join([IMAGE_TOKEN] * n_images) + content[0] + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, + sequence_start) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, + sequence_start) diff --git a/lmdeploy/vl/model/yi.py b/lmdeploy/vl/model/yi.py index 34b993322e..f8d3a907ff 100644 --- a/lmdeploy/vl/model/yi.py +++ b/lmdeploy/vl/model/yi.py @@ -1,12 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import os from contextlib import contextmanager +from os import path as osp +from typing import Dict, List import torch.nn as nn from transformers import AutoConfig from lmdeploy.vl.model.base import VISION_MODELS -from lmdeploy.vl.model.llava import LlavaVisionModel, check_llava_install +from lmdeploy.vl.model.llava import (LlavaVisionModel, check_llava_install, + process_images) from .utils import disable_transformers_logging, rewrite_ctx @@ -96,8 +99,22 @@ def match(cls, config: AutoConfig): return True return False + def build_preprocessor(self): + from transformers import CLIPImageProcessor + vision_tower_name = osp.join(self.model_path, + self.hf_config.mm_vision_tower) + self.image_processor = CLIPImageProcessor.from_pretrained( + vision_tower_name) + config = AutoConfig.from_pretrained(vision_tower_name) + image_size = config.image_size + patch_size = config.patch_size + self.n_token_per_image = (image_size // patch_size)**2 + if self.hf_config.mm_vision_select_feature == 'cls_patch': + self.n_token_per_image += 1 + def build_model(self): - """build model & load weights.""" + """build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" check_llava_install() global _model_path @@ -105,3 +122,19 @@ def build_model(self): with init_yi_model(), disable_transformers_logging(): super().build_model() + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """refer to `super().preprocess() for spec.""" + images = self.collect_images(messages) + outputs = [] + for image, params in images: + image = image.convert('RGB') + pixel_values = process_images([image], self.image_processor, + self.config) + outputs.append( + dict(pixel_values=pixel_values, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=0)) + messages.append(dict(role='preprocess', content=outputs)) + return messages diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py deleted file mode 100644 index cdf398868a..0000000000 --- a/lmdeploy/vl/templates.py +++ /dev/null @@ -1,550 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import asyncio -from typing import Dict, List, Tuple, Union - -import PIL -import PIL.Image - -from lmdeploy.archs import get_model_arch -from lmdeploy.model import BaseModel -from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import IMAGE_TOKEN -from lmdeploy.vl.utils import load_image - -logger = get_logger('lmdeploy') - -VLPromptType = Union[str, Tuple[str, PIL.Image.Image], - Tuple[str, List[PIL.Image.Image]]] - - -class VLChatTemplateWrapper: - """vl chat template wrapper.""" - - def __init__(self, chat_template: BaseModel): - self.chat_template = chat_template - - def prompt_to_messages(self, prompt: VLPromptType): - """convert prompt to GTP4V format.""" - messages = { - 'role': 'user', - 'content': [{ - 'type': 'text', - 'text': '', - }] - } - if isinstance(prompt, str): - messages['content'][0]['text'] = prompt - else: - prompt, images = prompt - if not isinstance(images, list): - images = [images] - messages['content'][0]['text'] = prompt - for image in images: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. - if isinstance(image, str): - image = load_image(image) - item = { - 'type': 'image_data', - 'image_data': { - 'data': image - } - } - elif isinstance(image, PIL.Image.Image): - item = { - 'type': 'image_data', - 'image_data': { - 'data': image - } - } - else: - raise ValueError( - 'image should be a str(url/path) or PIL.Image.Image') - - messages['content'].append(item) - - return [messages] - - async def async_collect_pil_images( - self, messages: Dict) -> List[Tuple[PIL.Image.Image, Dict]]: - """collect image from messages.""" - images_with_kwargs = [] - for message in messages: - role = message['role'] - content = message['content'] - if role != 'user' or isinstance(content, str): - continue - for item in content: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. - if item['type'] == 'image_url': - item_copy = item['image_url'].copy() - try: - url = item_copy.pop('url') - images_with_kwargs.append([url, item_copy]) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'image_data': - item_copy = item['image_data'].copy() - try: - data = item_copy.pop('data') - images_with_kwargs.append([data, item_copy]) - except KeyError: - logger.error(f'invalid format {message}') - - def _inner_call(i, images): - url_or_data = images[i][0] - images[i][0] = load_image(url_or_data) - - await asyncio.gather(*[ - asyncio.get_event_loop().run_in_executor(None, _inner_call, i, - images_with_kwargs) - for i in range(len(images_with_kwargs)) - ]) - - return images_with_kwargs - - def append_image_token(self, prompt, num_images: int): - """append image token to user prompt.""" - if IMAGE_TOKEN in prompt: - return prompt - return (IMAGE_TOKEN + '\n') * num_images + prompt - - def convert_messages(self, messages, sequence_start=True): - """convert GPT4V message format to GPT4 text format.""" - new_messages = [] - for message in messages: - role = message['role'] - content = message['content'] - if role != 'user' or isinstance(content, str): - if isinstance(content, list): - text = content[0]['text'] - message = {'role': role, 'content': text} - new_messages.append(message) - continue - num_images = 0 - for item in content: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. - if item['type'] == 'image_url': - num_images += 1 - elif item['type'] == 'image_data': - num_images += 1 - elif item['type'] == 'text': - prompt = item['text'] - if num_images > 0: - # add IMAGE_TOKEN to user prompt - prompt = self.append_image_token(prompt, num_images) - new_item = {'role': 'user', 'content': prompt} - new_messages.append(new_item) - return new_messages - - def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str: - """convert messages to decorated prompt.""" - if isinstance(messages, str): - return self.chat_template.messages2prompt(messages, sequence_start) - new_messages = self.convert_messages(messages, sequence_start) - return self.chat_template.messages2prompt(new_messages, sequence_start) - - -class LlavaVLChatTemplateWrapper(VLChatTemplateWrapper): - """Llava vl chat template.""" - pass - - -class YiVLChatTemplateWrapper(VLChatTemplateWrapper): - """Yi vl chat template.""" - pass - - -class InternVLChatTemplateWrapper(VLChatTemplateWrapper): - """InternVL chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - # lmdeploy uses as image token - # internvl uses special tags - if IMAGE_TOKEN in prompt and f'{IMAGE_TOKEN}' not in prompt: - prompt = prompt.replace(f'{IMAGE_TOKEN}', - f'{IMAGE_TOKEN}') - prompt = prompt.replace('', '') - prompt = prompt.replace('', '') - prompt = prompt.replace('', '') - elif IMAGE_TOKEN not in prompt: - prompt = f'{IMAGE_TOKEN * num_images}\n' + prompt - return prompt - - -class DeepSeekVLChatTemplateWrapper(VLChatTemplateWrapper): - """DeepSeek vl chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if IMAGE_TOKEN in prompt: - return prompt - logger.error( - f'for deepseek-vl model, the user should insert the {IMAGE_TOKEN} ' - 'to user prompt manually, please read https://lmdeploy.readthedocs' - '.io/en/latest/inference/vl_pipeline.html for more details.') - if num_images == 1: - return f'{IMAGE_TOKEN}{prompt}' - res = '' - for i in range(num_images): - res += f'{IMAGE_TOKEN} is Figure {str(i)}.\n' - res = res + prompt - return res - - -class QwenVLChatTemplateWrapper(VLChatTemplateWrapper): - """Qwen vl chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if IMAGE_TOKEN in prompt: - return prompt - res = '' - for i in range(num_images): - res += f'Picture {str(i)}:{IMAGE_TOKEN}\n' - res = res + prompt - return res - - -class Qwen2VLChatTemplateWrapper(VLChatTemplateWrapper): - """qwen2 vl.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt: - prompt = prompt.replace( - IMAGE_TOKEN, f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>') - else: - # Qwen2-VL-2B-Instruct will concat image and user prompt according - # to their order in the content list - # we insert image token before user prompt by default. The user can - # use custom image token position if they want the same decorated - # prompt as Qwen2-VL - prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \ - num_images + prompt - return prompt - - def get_mrope_info(self, - seq_len: int, - grid_thws: List[Tuple[int, int, int]] = None, - embedding_ranges: List[Tuple[int, int]] = None): - import torch - if grid_thws is None: - mrope_position_ids = torch.arange(seq_len).expand(3, -1) - mrope_position_delta = torch.tensor([0], dtype=torch.long) - else: - mrope_position_ids = [ - torch.arange(embedding_ranges[0][0]).expand(3, -1) - ] - st_idx = embedding_ranges[0][0] - for i, (grid_thw, embedding_range) in enumerate( - zip(grid_thws, embedding_ranges)): - llm_grid_t, llm_grid_h, llm_grid_w = grid_thw - llm_grid_h //= 2 - llm_grid_w //= 2 - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() - mrope_position_ids.append( - torch.stack([t_index, h_index, w_index]) + st_idx) - st_idx += max(llm_grid_h, llm_grid_w) - if i < len(embedding_ranges) - 1: - text_len = embedding_ranges[i + - 1][0] - embedding_ranges[i][1] - else: - text_len = seq_len - embedding_range[1] - mrope_position_ids.append( - torch.arange(text_len).expand(3, -1) + st_idx) - st_idx += text_len - mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) - mrope_position_delta = torch.tensor([st_idx - seq_len], - dtype=torch.long) - - return mrope_position_ids, mrope_position_delta - - -class CogVLMChatTemplateWrapper(VLChatTemplateWrapper): - """cogvlm chat template wrapper.""" - - def __init__(self, chat_template: BaseModel): - from lmdeploy.model import Vicuna - self.chat_template = chat_template - self.llm_chat_template = Vicuna(eoa=chat_template.eoa, - stop_words=chat_template.stop_words) - - def convert_messages(self, messages, sequence_start=True): - """convert GPT4V message format to GPT4 text format.""" - new_messages = [] - for message in messages: - role = message['role'] - content = message['content'] - if role != 'user' or isinstance(content, str): - new_messages.append(message) - continue - num_images = 0 - for item in content: - if item['type'] == 'image_url': - num_images += 1 - elif item['type'] == 'image_data': - num_images += 1 - elif item['type'] == 'text': - prompt = item['text'] - - new_item = { - 'role': 'user', - 'content': prompt, - 'num_images': num_images - } - new_messages.append(new_item) - return new_messages - - def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str: - """convert messages to decorated prompt.""" - if isinstance(messages, str): - return self.chat_template.messages2prompt(messages, sequence_start) - new_messages = self.convert_messages(messages, sequence_start) - prompt = '' - for i, msg in enumerate(new_messages): - num_images = msg.pop('num_images', 0) - if num_images == 0: - role = msg['role'] - msg = self.llm_chat_template.messages2prompt([msg], - sequence_start - and i == 0) - msg = dict(role=role, content=msg) - prompt_i = self.chat_template.messages2prompt([msg], sequence_start - and i == 0) - if num_images > 0: - prompt_i = (IMAGE_TOKEN * num_images) + prompt_i - prompt += prompt_i - return prompt - - -class InternLMXComposer2TemplateWrapper(VLChatTemplateWrapper): - """InternLM-XComposer2 chat template.""" - - def append_image_token(self, prompt, num_images: int): - if IMAGE_TOKEN in prompt: - return prompt - logger.warning(f'auto append {IMAGE_TOKEN} at the beginning, ' - 'the user can manually insert the token to prompt') - return ' '.join([IMAGE_TOKEN] * num_images) + prompt - - -class MiniGeminiLlamaTempateWrapper(VLChatTemplateWrapper): - """Qwen vl chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - if num_images == 0: - return prompt - if IMAGE_TOKEN in prompt: - return prompt - res = f'{IMAGE_TOKEN}\n' - assert num_images <= 1, 'MiniGeminiLlama accepts 1 input image' - res = res + prompt - return res - - -class MllamaTempateWrapper(VLChatTemplateWrapper): - """Mllama chat template.""" - - def append_image_token(self, prompt, num_images: int): - """append image tokens to user prompt.""" - return f'{IMAGE_TOKEN * num_images}{prompt}' - - -class MiniCPMVTempateWrapper(VLChatTemplateWrapper): - """MiniCPM-Llama3-V-2_5 chat template.""" - - def append_image_token(self, prompt, num_images: int): - if IMAGE_TOKEN in prompt: - return prompt - prompt = f'{IMAGE_TOKEN}\n' * num_images + prompt - return prompt - - def update_image_token(self, prompt, features): - _features = [] - _prompt = [] - segs = prompt.split(f'{IMAGE_TOKEN}\n') - for i, seg in enumerate(segs): - if i > 0 and i <= len(features): - _feat = features[i - 1]['embeddings'].split(1) - _feat = [x.squeeze() for x in _feat] - _features.extend(_feat) - _seg = f'{IMAGE_TOKEN}' - if len(_feat) > 1: - grid = features[i - 1]['grid'] - if grid is not None: - _slice = '\n'.join( - [f'{IMAGE_TOKEN}' * grid[0]] * - grid[1]) - _seg = f'{_seg}{_slice}\n' - _prompt.append(_seg) - _prompt.append(seg) - _prompt = ''.join(_prompt) - return _prompt, _features - - -class MiniCPMV26TempateWrapper(MiniCPMVTempateWrapper): - """MiniCPM-V-2_6 chat template.""" - - def update_image_token(self, prompt, features): - _features = [] - _prompt = [] - segs = prompt.split(f'{IMAGE_TOKEN}\n') - idx = 0 - for i, seg in enumerate(segs): - if i > 0 and i <= len(features): - _feat = features[i - 1]['embeddings'].split(1) - _feat = [x.squeeze() for x in _feat] - _features.extend(_feat) - _seg = f'{IMAGE_TOKEN}' - if features[i - 1].get('use_image_id', False): - _seg = f'{idx}' + _seg - idx += 1 - if len(_feat) > 1: - grid = features[i - 1]['grid'] - if grid is not None: - _slice = '\n'.join( - [f'{IMAGE_TOKEN}' * grid[0]] * - grid[1]) - _seg = _seg + _slice - _seg += '\n' - _prompt.append(_seg) - _prompt.append(seg) - _prompt = ''.join(_prompt) - return _prompt, _features - - -class GLM4VChatTemplateWrapper(VLChatTemplateWrapper): - """glm-4v chat template.""" - pass - - -class MolmoChatTemplateWrapper(VLChatTemplateWrapper): - - async def async_collect_pil_images( - self, messages: List[Dict]) -> List[Tuple[PIL.Image.Image, Dict]]: - """collect images from messages. - - Args: - messages (List[Dict]): a user request of GPT4V message format - """ - if isinstance(messages, Dict): - messages = [messages] - assert isinstance(messages, List) - - out_messages = [None] * len(messages) - - def _inner_call(i, in_messages, out_messages): - role = in_messages[i]['role'] - content = in_messages[i]['content'] - if role != 'user' or isinstance(content, str): - # means message is user's prompt input or assistant's prompt, - # returning it directory - out_messages[i] = in_messages[i] - return - # the role is a user and the content is a list - assert isinstance(content, List) - message = dict(role=role, content='', images=[]) - for item in content: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. - if item['type'] == 'image_url': - try: - image = load_image(item['image_url']['url']) - message['images'].append(image) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'image_data': - try: - image = load_image(item['image_data']['data']) - message['images'].append(image) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'text': - message['content'] = item['text'] - else: - logger.error(f'unexpected content type {message}') - out_messages[i] = message - - await asyncio.gather(*[ - asyncio.get_event_loop().run_in_executor(None, _inner_call, i, - messages, out_messages) - for i in range(len(messages)) - ]) - return [(None, out_messages)] - - def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str: - """Return a placeholder "IMAGE_TOKEN" so that - `vl_asyn_engine._get_prompt_input` can know that it.""" - if isinstance(messages, str): - return self.chat_template.messages2prompt(messages, sequence_start) - else: - _messages = [] - for message in messages: - role, content = message['role'], message['content'] - if role != 'user' or isinstance(content, str): - _messages.append(message) - continue - for item in content: - item_type = item['type'] - if item_type in ['image_url', 'image_data']: - # Return the image placeholder so that - # `vl_asyn_engine._get_prompt_input` can know that the - # request contains images - return IMAGE_TOKEN - _messages.append(dict(role=role, content=item[item_type])) - return self.chat_template.messages2prompt(_messages, - sequence_start) - - -def get_vl_prompt_template(model_path: str, chat_template: BaseModel, - model_name: str) -> VLChatTemplateWrapper: - """get vision language prompt template.""" - assert type(chat_template) != type(BaseModel()), 'failed to match ' \ - 'chat template, please explicit set chat_template_config' # noqa E721 - if model_name == 'yi-vl': - return YiVLChatTemplateWrapper(chat_template) - arch, cfg = get_model_arch(model_path) - if arch == 'QWenLMHeadModel': - return QwenVLChatTemplateWrapper(chat_template) - elif arch in [ - 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM', - 'LlavaForConditionalGeneration', - 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM' - ]: - return LlavaVLChatTemplateWrapper(chat_template) - elif arch == 'MultiModalityCausalLM': # deepseek-vl - return DeepSeekVLChatTemplateWrapper(chat_template) - elif arch == 'MllamaForConditionalGeneration': # llama 3.2 - return MllamaTempateWrapper(chat_template) - elif arch == 'CogVLMForCausalLM': - return CogVLMChatTemplateWrapper(chat_template) - elif arch in ['InternLMXComposer2ForCausalLM', 'InternLM2ForCausalLM']: - return InternLMXComposer2TemplateWrapper(chat_template) - elif arch == 'InternVLChatModel': - return InternVLChatTemplateWrapper(chat_template) - elif arch in ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM']: - return MiniGeminiLlamaTempateWrapper(chat_template) - elif arch == 'MiniCPMV': - version_map = { - '2.5': MiniCPMVTempateWrapper, - '2.6': MiniCPMV26TempateWrapper - } - version = str(getattr(cfg, 'version', '2.5')) - return version_map[version](chat_template) - elif arch == 'ChatGLMModel': - return GLM4VChatTemplateWrapper(chat_template) - elif arch == 'Qwen2VLForConditionalGeneration': - return Qwen2VLChatTemplateWrapper(chat_template) - elif arch == 'MolmoForCausalLM': - return MolmoChatTemplateWrapper(chat_template) - raise ValueError(f'unsupported vl_prompt_template with arch {arch}') diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index 05d74bbe72..81f538275c 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -1,5 +1,5 @@ accelerate>=0.29.3 -dlinfer-ascend>=0.1.2 +dlinfer-ascend>=0.1.3 einops fastapi fire @@ -16,7 +16,8 @@ safetensors sentencepiece shortuuid tiktoken -torch<=2.4.0,>=2.0.0 -torchvision<=0.19.0,>=0.15.0 +torch<=2.4.0,>=2.3.1 +torch-npu==2.3.1 +torchvision<=0.19.0,>=0.18.1 transformers uvicorn diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt new file mode 100644 index 0000000000..41af6039bd --- /dev/null +++ b/requirements/runtime_cuda.txt @@ -0,0 +1,22 @@ +accelerate>=0.29.3 +einops +fastapi +fire +mmengine-lite +numpy<2.0.0 +openai +outlines<0.1.0 +peft<=0.11.1 +pillow +protobuf +pydantic>2.0.0 +pynvml +safetensors +sentencepiece +shortuuid +tiktoken +torch<=2.5.1,>=2.0.0 +torchvision<=0.20.1,>=0.15.0 +transformers +triton==3.0.0; sys_platform == "linux" +uvicorn diff --git a/requirements/runtime.txt b/requirements/runtime_maca.txt similarity index 78% rename from requirements/runtime.txt rename to requirements/runtime_maca.txt index 400c492b09..f65b3827cd 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime_maca.txt @@ -1,4 +1,4 @@ -accelerate>=0.29.3 +accelerate==0.32.1 einops fastapi fire @@ -18,5 +18,5 @@ tiktoken torch<=2.4.0,>=2.0.0 torchvision<=0.19.0,>=0.15.0 transformers -triton>=2.2.0,<=3.0.0; sys_platform == "linux" +triton>=2.1.0; sys_platform == "linux" uvicorn diff --git a/requirements.txt b/requirements_cuda.txt similarity index 70% rename from requirements.txt rename to requirements_cuda.txt index 91d38808f1..7c1d387dfb 100644 --- a/requirements.txt +++ b/requirements_cuda.txt @@ -1,4 +1,4 @@ -r requirements/build.txt --r requirements/runtime.txt +-r requirements/runtime_cuda.txt -r requirements/lite.txt -r requirements/serve.txt diff --git a/requirements_maca.txt b/requirements_maca.txt new file mode 100644 index 0000000000..075b132c8c --- /dev/null +++ b/requirements_maca.txt @@ -0,0 +1,4 @@ +-r requirements/build.txt +-r requirements/runtime_maca.txt +-r requirements/lite.txt +-r requirements/serve.txt diff --git a/setup.py b/setup.py index 7a08ac7919..52e180d8a2 100644 --- a/setup.py +++ b/setup.py @@ -4,18 +4,14 @@ from setuptools import find_packages, setup -npu_available = False -try: - import torch_npu - - npu_available = torch_npu.npu.is_available() -except ImportError: - pass - pwd = os.path.dirname(__file__) version_file = 'lmdeploy/version.py' +def get_target_device(): + return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda') + + def readme(): with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f: content = f.read() @@ -154,16 +150,12 @@ def gen_packages_items(): setup_requires=parse_requirements('requirements/build.txt'), tests_require=parse_requirements('requirements/test.txt'), install_requires=parse_requirements( - 'requirements/runtime_ascend.txt' - if npu_available else 'requirements/runtime.txt'), + f'requirements/runtime_{get_target_device()}.txt'), extras_require={ 'all': - parse_requirements('requirements_ascend.txt' - if npu_available else 'requirements.txt'), - 'lite': - parse_requirements('requirements/lite.txt'), - 'serve': - parse_requirements('requirements/serve.txt') + parse_requirements(f'requirements_{get_target_device()}.txt'), + 'lite': parse_requirements('requirements/lite.txt'), + 'serve': parse_requirements('requirements/serve.txt') }, has_ext_modules=check_ext_modules, classifiers=[ diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py index 7d4b7a7f3a..e56de44b37 100644 --- a/tests/pytorch/kernel/test_flash_attention.py +++ b/tests/pytorch/kernel/test_flash_attention.py @@ -10,20 +10,26 @@ def _conti_input(data, q_seqlens): return data -def _make_bias(q_seqlens, history_lens, neg_val): - full_seq_lens = q_seqlens + history_lens +def _make_bias(q_seqlens, history_lens, neg_val, causal): + kv_seqlens = q_seqlens + history_lens max_seq_len = q_seqlens.max().item() - max_full_len = full_seq_lens.max().item() - seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] - for r, l in zip(seq_ranges, q_seqlens): - r[l:] = -max_full_len - seq_ranges = torch.stack(seq_ranges, dim=0).cuda() - kv_ranges = [torch.arange(max_full_len) for _ in full_seq_lens] - kv_ranges = torch.stack(kv_ranges, 0).cuda() - mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, - None, - None] - return mask.float() * neg_val + max_kv_len = kv_seqlens.max().item() + if causal: + seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] + for r, l in zip(seq_ranges, q_seqlens): + r[l:] = -max_kv_len + seq_ranges = torch.stack(seq_ranges, dim=0).cuda() + kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens] + kv_ranges = torch.stack(kv_ranges, 0).cuda() + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > + history_lens[:, None, None]) + return mask.float() * neg_val + else: + q_mask = torch.arange(max_seq_len)[None].cuda() < q_seqlens[:, None] + k_mask = torch.arange(max_kv_len)[None].cuda() < kv_seqlens[:, None] + mask = q_mask[:, :, None] & k_mask[:, None, :] + + return (~mask).float() * neg_val def _naive_attention(batched_q, batched_kv, bias): @@ -100,6 +106,10 @@ def num_heads_q(self, request): def num_heads_k(self, request): yield request.param + @pytest.fixture + def causal(self, request): + yield request.param + @pytest.fixture def q_seqlens(self, request): yield torch.tensor(request.param, device='cuda') @@ -138,8 +148,8 @@ def batched_kv(self, q_seqlens, history_lens, num_heads_k, head_dim_k, head_dim_v, dtype): torch.manual_seed(123) batch_size = len(q_seqlens) - full_seq_lens = q_seqlens + history_lens - max_seq_len = full_seq_lens.max().item() + kv_seqlens = q_seqlens + history_lens + max_seq_len = kv_seqlens.max().item() k = torch.rand(batch_size, max_seq_len, num_heads_k, @@ -167,9 +177,9 @@ def conti_kv(self, kv_seqlens, batched_kv): yield (conti_k, conti_v) @pytest.fixture - def mask(self, q_seqlens, history_lens): + def mask(self, q_seqlens, history_lens, causal): neg_val = -1e30 - yield _make_bias(q_seqlens, history_lens, neg_val) + yield _make_bias(q_seqlens, history_lens, neg_val, causal) @pytest.fixture def gt(self, batched_q, batched_kv, mask): @@ -183,11 +193,13 @@ def conti_gt(self, gt, q_seqlens): @pytest.mark.parametrize('head_dim_v', [32], indirect=True) @pytest.mark.parametrize('num_heads_q', [8, 2], indirect=True) @pytest.mark.parametrize('num_heads_k', [2], indirect=True) + @pytest.mark.parametrize('causal', [True, False], indirect=True) @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True) def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, - kv_start_loc, kv_seqlens, head_dim_v, conti_gt): + kv_start_loc, kv_seqlens, head_dim_v, causal, + conti_gt): from lmdeploy.pytorch.kernels.cuda.flashattention import \ flash_attention_fwd max_seq_len = q_seqlens.max().item() @@ -202,7 +214,8 @@ def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, q_seqlens=q_seqlens, kv_start_loc=kv_start_loc, kv_seqlens=kv_seqlens, - max_seqlen=max_seq_len) + max_seqlen=max_seq_len, + causal=causal) torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) @pytest.fixture diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py index 3b78053a74..0e53283a87 100644 --- a/tests/test_lmdeploy/test_model.py +++ b/tests/test_lmdeploy/test_model.py @@ -220,7 +220,7 @@ def test_llama3_1(): }, }] actual_prompt = model.messages2prompt(messages, tools=tools) - expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa + expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa assert actual_prompt == expected_prompt diff --git a/tests/test_lmdeploy/test_vl_encode.py b/tests/test_lmdeploy/test_vl/test_vl_encode.py similarity index 100% rename from tests/test_lmdeploy/test_vl_encode.py rename to tests/test_lmdeploy/test_vl/test_vl_encode.py diff --git a/tests/test_lmdeploy/test_vl_template.py b/tests/test_lmdeploy/test_vl_template.py deleted file mode 100644 index cf8abf9e44..0000000000 --- a/tests/test_lmdeploy/test_vl_template.py +++ /dev/null @@ -1,132 +0,0 @@ -import PIL - -from lmdeploy.model import MODELS -from lmdeploy.vl.constants import IMAGE_TOKEN -from lmdeploy.vl.templates import VLChatTemplateWrapper - - -def test_prompt_to_messages(): - model = MODELS.get('llava-v1')() - templtae = VLChatTemplateWrapper(model) - out = templtae.prompt_to_messages('hi') - assert isinstance(out, list) and isinstance(out[0], dict) - im = PIL.Image.new(mode='RGB', size=(200, 200)) - out = templtae.prompt_to_messages(('hi', [im])) - assert isinstance(out, list) and isinstance(out[0], dict) - - -def test_messages2prompt(): - model = MODELS.get('llava-v1')() - templtae = VLChatTemplateWrapper(model) - messages = [ - dict(role='user', - content=[ - dict(type='text', text='q1'), - dict(type='image_url', image_url=dict(url='xxx')) - ]) - ] - prompt = templtae.messages2prompt(messages) - assert isinstance(prompt, str) - assert prompt.count(IMAGE_TOKEN) == 1 - expected = ( - 'A chat between a curious human and an artificial intelligence ' - 'assistant. The assistant gives helpful, detailed, and polite ' - "answers to the human's questions. USER: " - '\nq1 ASSISTANT:') - assert prompt == expected - - messages.append({'role': 'assistant', 'content': 'a1'}) - messages.append({'role': 'user', 'content': 'q2'}) - prompt = templtae.messages2prompt(messages) - expected = ( - 'A chat between a curious human and an artificial intelligence ' - 'assistant. The assistant gives helpful, detailed, and polite ' - "answers to the human's questions. USER: " - '\nq1 ASSISTANT: a1USER: q2 ASSISTANT:') - assert prompt == expected - - -def test_internvl2_conv(): - # https://huggingface.co/OpenGVLab/InternVL2-8B/blob/3bfd3664dea4f3da628785f5125d30f889701253/conversation.py - from transformers.dynamic_module_utils import get_class_from_dynamic_module - get_conv_template = get_class_from_dynamic_module( - 'conversation.get_conv_template', 'OpenGVLab/InternVL2-8B') - template = get_conv_template('internlm2-chat') - question1 = 'question1' - template.append_message(template.roles[0], question1) - template.append_message(template.roles[1], None) - model = MODELS.get('internvl2-internlm2')() - messages = [dict(role='user', content=question1)] - assert template.get_prompt() == model.messages2prompt(messages) - - answer1 = 'answer1' - template.messages[-1][1] = answer1 - question2 = 'question2' - template.append_message(template.roles[0], question2) - template.append_message(template.roles[1], None) - messages.append(dict(role='assistant', content=answer1)) - messages.append(dict(role='user', content=question2)) - assert template.get_prompt() == model.messages2prompt(messages) - - -def test_llava_conv_chatml_direct(): - model = MODELS.get('llava-chatml')() - templtae = VLChatTemplateWrapper(model) - messages = [ - dict(role='user', - content=[ - dict(type='text', text='q1'), - dict(type='image_url', image_url=dict(url='xxx')) - ]) - ] - - prompt = templtae.messages2prompt(messages) - expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>' - '<|im_start|>user\n\nq1<|im_end|>' - '<|im_start|>assistant\n') - assert prompt == expected - - messages.append({'role': 'assistant', 'content': 'a1'}) - messages.append({'role': 'user', 'content': 'q2'}) - prompt = templtae.messages2prompt(messages) - expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>' - '<|im_start|>user\n\nq1<|im_end|>' - '<|im_start|>assistant\na1<|im_end|>' - '<|im_start|>user\nq2<|im_end|>' - '<|im_start|>assistant\n') - assert prompt == expected - - -def test_custom_image_token(): - from lmdeploy.vl.templates import DeepSeekVLChatTemplateWrapper - model = MODELS.get('deepseek-vl')() - template = DeepSeekVLChatTemplateWrapper(model) - - def create_user(query: str): - item = dict(role='user', content=[dict(type='text', text=query)]) - num = query.count(IMAGE_TOKEN) - for _ in range(num): - item['content'].append( - dict(type='image_url', image_url=dict(url='xxx'))) - return item - - def create_assistant(response: str): - return dict(role='assistant', content=response) - - messages = [create_user(f'{IMAGE_TOKEN} q1')] - prompt = template.messages2prompt(messages) - expected = ('You are a helpful language and vision assistant. You are able' - ' to understand the visual content that the user provides, and' - ' assist the user with a variety of tasks using natural ' - 'language.\n\nUser: q1\n\nAssistant:') - assert prompt == expected - - messages.append(create_assistant('a1')) - messages.append(create_user(f'q2 {IMAGE_TOKEN}')) - prompt = template.messages2prompt(messages) - expected = ('You are a helpful language and vision assistant. You are able' - ' to understand the visual content that the user provides, and' - ' assist the user with a variety of tasks using natural ' - 'language.\n\nUser: q1\n\nAssistant: ' - 'a1<|end▁of▁sentence|>User: q2 \n\nAssistant:') - assert prompt == expected