diff --git a/.github/workflows/pr_ete_test.yml b/.github/workflows/pr_ete_test.yml
index 3a19ebe870..2d1c4b63f5 100644
--- a/.github/workflows/pr_ete_test.yml
+++ b/.github/workflows/pr_ete_test.yml
@@ -10,7 +10,7 @@ on:
- "3rdparty/**"
- "lmdeploy/**"
- "requirements/**"
- - "requirements.txt"
+ - "requirements_cuda.txt"
- "CMakeLists.txt"
- "setup.py"
workflow_dispatch:
@@ -68,7 +68,7 @@ jobs:
export PATH=$PATH:/usr/local/openmpi/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/openmpi/lib
python3 -m pip install cmake packaging wheel transformers_stream_generator transformers datasets openai einops timm decord
- python3 -m pip install -r requirements.txt -r requirements/test.txt -r requirements/build.txt
+ python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt -r requirements/build.txt
mkdir -p build && cd build &&\
sh ../generate.sh &&\
ninja -j$(nproc) && ninja install &&\
diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index ec6db0682d..3a459050ec 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -10,7 +10,7 @@ on:
- "3rdparty/**"
- "lmdeploy/**"
- "requirements/**"
- - "requirements.txt"
+ - "requirements_cuda.txt"
- "CMakeLists.txt"
- "setup.py"
push:
@@ -24,7 +24,7 @@ on:
- "3rdparty/**"
- "lmdeploy/**"
- "requirements/**"
- - "requirements.txt"
+ - "requirements_cuda.txt"
- "CMakeLists.txt"
- "setup.py"
tags:
@@ -39,6 +39,7 @@ jobs:
options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e CUDA_VISIBLE_DEVICES=2,3 --pull never"
volumes:
- /nvme/share_data/github-actions/pip-cache:/root/.cache/pip
+ - /nvme/share_data/github-actions/hf_home:/root/.cache/huggingface
- /nvme/share_data/github-actions/packages:/root/packages
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
@@ -78,7 +79,7 @@ jobs:
python3 -m pip install pynvml packaging protobuf transformers_stream_generator
# manually install flash attn
python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
- python3 -m pip install -r requirements.txt -r requirements/test.txt
+ python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt
python3 -m pip install .
- name: Check env
run: |
diff --git a/README.md b/README.md
index d160338aa6..8ef7b7994f 100644
--- a/README.md
+++ b/README.md
@@ -125,6 +125,8 @@ For detailed inference benchmarks in more devices and more settings, please refe
Qwen1.5 (0.5B - 110B)
Qwen1.5 - MoE (0.5B - 72B)
Qwen2 (0.5B - 72B)
+ Qwen2-MoE (57BA14B)
+ Qwen2.5 (0.5B - 32B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
@@ -136,6 +138,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
Mistral (7B)
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
+ DeepSeek-V2.5 (236B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
Dbrx (132B)
diff --git a/README_ja.md b/README_ja.md
index fda176229e..77badaac36 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -122,6 +122,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
Qwen1.5 (0.5B - 110B)
Qwen1.5 - MoE (0.5B - 72B)
Qwen2 (0.5B - 72B)
+ Qwen2-MoE (57BA14B)
+ Qwen2.5 (0.5B - 32B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
@@ -133,6 +135,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
Mistral (7B)
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
+ DeepSeek-V2.5 (236B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
Dbrx (132B)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 6c24b2e500..9f3cd40a64 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -126,6 +126,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
Qwen1.5 (0.5B - 110B)
Qwen1.5 - MoE (0.5B - 72B)
Qwen2 (0.5B - 72B)
+ Qwen2-MoE (57BA14B)
+ Qwen2.5 (0.5B - 32B)
Baichuan (7B)
Baichuan2 (7B-13B)
Code Llama (7B - 34B)
@@ -137,6 +139,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
Mistral (7B)
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
+ DeepSeek-V2.5 (236B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
Dbrx (132B)
diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py
index 24b4a3f8cd..dd8db1ccc4 100644
--- a/autotest/utils/config_utils.py
+++ b/autotest/utils/config_utils.py
@@ -97,7 +97,7 @@ def get_all_model_list(tp_num: int = None,
model_type=model_type):
if case not in case_list:
case_list.append(case)
- return [x for x in case_list if 'w8a8' not in x]
+ return case_list
def get_quantization_model_list(type):
diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index d41f886227..3a343e8f5b 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -377,15 +377,17 @@ def main():
requests = sample_requests(args.dataset, args.num_prompts,
engine.tokenizer)
- engine.process_request(requests,
- temperature=args.temperature,
- top_p=args.top_p,
- top_k=args.top_k,
- concurrency=args.concurrency,
- stream_output=not args.no_stream_output,
- skip_tokenize=args.skip_tokenize,
- skip_detokenize=args.skip_detokenize,
- cancel_rate=args.cancel_rate)
+ engine.process_request(
+ requests,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ concurrency=args.concurrency
+ if args.concurrency < args.num_prompts else args.num_prompts,
+ stream_output=not args.no_stream_output,
+ skip_tokenize=args.skip_tokenize,
+ skip_detokenize=args.skip_detokenize,
+ cancel_rate=args.cancel_rate)
if __name__ == '__main__':
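For illustration, the inline conditional above is simply a clamp and is equivalent to:

```python
# Equivalent form of the concurrency clamp above (illustration only).
concurrency = min(args.concurrency, args.num_prompts)
```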
diff --git a/docker/Dockerfile b/docker/Dockerfile
index caa58ee637..24b2b055da 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -10,9 +10,6 @@ FROM ${CUDA_VERSION} AS final
ARG PYTHON_VERSION=3.10
-ARG TORCH_VERSION=2.3.0
-ARG TORCHVISION_VERSION=0.18.0
-
RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl openssh-server ssh sudo &&\
curl https://sh.rustup.rs -sSf | sh -s -- -y &&\
add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \
@@ -43,7 +40,6 @@ ENV LD_LIBRARY_PATH=/usr/local/nccl/lib:$LD_LIBRARY_PATH
RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install --upgrade pip setuptools==69.5.1 &&\
- python3 -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\
python3 -m pip install cmake packaging wheel
ENV NCCL_LAUNCH_MODE=GROUP
@@ -54,7 +50,7 @@ COPY . /opt/lmdeploy
WORKDIR /opt/lmdeploy
RUN --mount=type=cache,target=/root/.cache/pip cd /opt/lmdeploy &&\
- python3 -m pip install -r requirements.txt &&\
+ python3 -m pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} &&\
mkdir -p build && cd build &&\
sh ../generate.sh &&\
ninja -j$(nproc) && ninja install &&\
diff --git a/docker/Dockerfile_aarch64_ascend b/docker/Dockerfile_aarch64_ascend
index 1c9591197b..ecc2d1334e 100644
--- a/docker/Dockerfile_aarch64_ascend
+++ b/docker/Dockerfile_aarch64_ascend
@@ -122,4 +122,4 @@ WORKDIR /opt/lmdeploy
RUN --mount=type=cache,target=/root/.cache/pip \
sed -i '/triton/d' requirements/runtime.txt && \
- pip3 install -v --no-build-isolation -e .
+ LMDEPLOY_TARGET_DEVICE=ascend pip3 install -v --no-build-isolation -e .
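For reference, a typical build command for this Dockerfile from the repository root might look like the following; the image tag is an assumption, not something fixed by this patch.

```shell
# Build the aarch64 Ascend image; the tag name is illustrative.
DOCKER_BUILDKIT=1 docker build -t lmdeploy-aarch64-ascend:latest \
    -f docker/Dockerfile_aarch64_ascend .
```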
diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md
index 23b86afa61..d104477ca1 100644
--- a/docs/en/get_started/ascend/get_started.md
+++ b/docs/en/get_started/ascend/get_started.md
@@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```
Please check [supported_models](../../supported_models/supported_models.md) before use this feature.
+
+### int8 KV-cache Quantization
+
+The Ascend backend supports offline int8 KV-cache quantization in eager mode.
+
+Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details.
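As a rough sketch of how this ties into the backend changes later in this patch (`ASCEND_QUANT_RECORD_FILE` and `quant_policy=8`), an eager-mode pipeline on Ascend might be set up as follows; the model id and record-file path are placeholders.

```python
# Hedged sketch: eager-mode int8 KV cache quantization on Ascend.
# The record file comes from the offline quantization step described in the linked doc.
import os

from lmdeploy import PytorchEngineConfig, pipeline

os.environ['ASCEND_QUANT_RECORD_FILE'] = '/path/to/kv_cache_quant_record.txt'
pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=PytorchEngineConfig(device_type='ascend',
                                                   eager_mode=True,
                                                   quant_policy=8))
print(pipe(['Hi, please introduce yourself']))
```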
diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md
index b3e8bb8abd..c00111c2ab 100644
--- a/docs/en/get_started/installation.md
+++ b/docs/en/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
The default prebuilt package is compiled on **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy by:
```shell
-export LMDEPLOY_VERSION=0.6.3
+export LMDEPLOY_VERSION=0.6.4
export PYTHON_VERSION=38
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
diff --git a/docs/en/llm/api_server.md b/docs/en/llm/api_server.md
index 285b0e32ff..274ec2ff25 100644
--- a/docs/en/llm/api_server.md
+++ b/docs/en/llm/api_server.md
@@ -249,6 +249,57 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \
lmdeploy serve gradio api_server_url --server-name ${gradio_ui_ip} --server-port ${gradio_ui_port}
```
+## Launch multiple api servers
+
+The following two steps launch multiple api servers through torchrun. To begin, create a python script with the code below.
+
+1. Launch the proxy server through `lmdeploy serve proxy` and take note of the proxy server url.
+2. Launch the script through `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **Note**: Please do not use `0.0.0.0:8000` here; instead, pass the real IP address, e.g. `11.25.34.55:8000`.
+
+```python
+import os
+import socket
+from typing import List, Literal
+
+import fire
+
+
+def get_host_ip():
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ s.connect(('8.8.8.8', 80))
+ ip = s.getsockname()[0]
+ finally:
+ s.close()
+ return ip
+
+
+def main(model_path: str,
+ tp: int = 1,
+ proxy_url: str = 'http://0.0.0.0:8000',
+ port: int = 23333,
+ backend: Literal['turbomind', 'pytorch'] = 'turbomind'):
+ local_rank = int(os.environ.get('LOCAL_RANK', -1))
+ world_size = int(os.environ.get('WORLD_SIZE', -1))
+ local_ip = get_host_ip()
+ if isinstance(port, List):
+ assert len(port) == world_size
+ port = port[local_rank]
+ else:
+ port += local_rank * 10
+ if (world_size - local_rank) % tp == 0:
+ rank_list = ','.join([str(local_rank + i) for i in range(tp)])
+ command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\
+ f'--server-name {local_ip} --server-port {port} --tp {tp} '\
+ f'--proxy-url {proxy_url} --backend {backend}'
+ print(f'running command: {command}')
+ os.system(command)
+
+
+if __name__ == '__main__':
+ fire.Fire(main)
+```
+
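For reference, the per-node commands then look roughly like the following; the IP address and model are placeholders taken from the note above.

```shell
# On the proxy node (use a reachable IP, not 0.0.0.0):
lmdeploy serve proxy --server-name 11.25.34.55 --server-port 8000

# On every worker node, run the script above once:
torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b \
    --proxy_url http://11.25.34.55:8000
```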
## FAQ
1. When user got `"finish_reason":"length"`, it means the session is too long to be continued. The session length can be
diff --git a/docs/en/multi_modal/llava.md b/docs/en/multi_modal/llava.md
index 8f052227d5..c374b67121 100644
--- a/docs/en/multi_modal/llava.md
+++ b/docs/en/multi_modal/llava.md
@@ -6,11 +6,17 @@ LMDeploy supports the following llava series of models, which are detailed in th
| :----------------------------------: | :--: | :------------------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch |
+| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
+| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
+| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
+| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
The next chapter demonstrates how to deploy an Llava model using LMDeploy, with [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example.
+```{note}
+The PyTorch engine removed support for the original llava models after v0.6.4. Please use the corresponding transformers-format models instead, which can be found at https://huggingface.co/llava-hf
+```
+
## Installation
Please install LMDeploy by following the [installation guide](../get_started/installation.md).
diff --git a/docs/en/multi_modal/qwen2_vl.md b/docs/en/multi_modal/qwen2_vl.md
index 8b59f84545..fd9f02abaa 100644
--- a/docs/en/multi_modal/qwen2_vl.md
+++ b/docs/en/multi_modal/qwen2_vl.md
@@ -4,7 +4,7 @@ LMDeploy supports the following Qwen-VL series of models, which are detailed in
| Model | Size | Supported Inference Engine |
| :----------: | :----: | :------------------------: |
-| Qwen-VL-Chat | - | TurboMind, Pytorch |
+| Qwen-VL-Chat | - | TurboMind |
| Qwen2-VL | 2B, 7B | PyTorch |
The next chapter demonstrates how to deploy an Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example.
diff --git a/docs/en/quantization/w4a16.md b/docs/en/quantization/w4a16.md
index 0aa1e17a5b..c36c3736c6 100644
--- a/docs/en/quantization/w4a16.md
+++ b/docs/en/quantization/w4a16.md
@@ -128,3 +128,7 @@ We benchmarked the Llama-2-7B-chat and Llama-2-13B-chat models with 4-bit quanti
| ---------------- | ------- | ------- | --------- |
| Llama-2-7B-chat | 112.9 | 159.4 | 206.4 |
| Llama-2-13B-chat | N/A | 90.7 | 115.8 |
+
+## FAQs
+
+1. Out of Memory error during quantization due to insufficient GPU memory: this can be addressed by reducing the parameter `--calib-seqlen`, increasing the parameter `--calib-samples`, and setting `--batch-size` to 1 (see the sketch below).
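A minimal sketch of such a re-run, assuming an InternLM2.5 model and an illustrative output directory:

```shell
# Re-run quantization with a shorter calibration context and batch size 1.
lmdeploy lite auto_awq internlm/internlm2_5-7b-chat \
    --calib-seqlen 1024 \
    --batch-size 1 \
    --work-dir ./internlm2_5-7b-chat-4bit
```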
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index 469ece487f..dd8ceb4ffa 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -10,7 +10,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
@@ -18,9 +18,13 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
| Mistral | 7B | LLM | Yes | Yes | Yes | No |
| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
@@ -29,7 +33,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
-| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes |
+| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
@@ -41,7 +45,8 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
"-" means not verified yet.
```{note}
-The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference.
+* The TurboMind engine doesn't support window attention. Therefore, for models that apply window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5, etc., please choose the PyTorch engine for inference.
+* When a model's head_dim is not 128, as in llama3.2-1B, qwen2-0.5B and internvl2-1B, the TurboMind engine doesn't support 4/8-bit KV cache quantization or inference for it.
```
## PyTorchEngine on CUDA Platform
@@ -68,11 +73,13 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha
| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
-| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes |
+| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
@@ -81,7 +88,7 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha
| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - |
-| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes |
+| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - |
| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - |
| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md
index b137c458be..9f0a7b1f90 100644
--- a/docs/zh_cn/get_started/ascend/get_started.md
+++ b/docs/zh_cn/get_started/ascend/get_started.md
@@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```
支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。
+
+### int8 KV-cache 量化
+
+昇腾后端现在支持了在eager模式下的离线int8 KV-cache量化。
+
+详细使用方式请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。
diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md
index 12562c51d5..0213fa6d15 100644
--- a/docs/zh_cn/get_started/installation.md
+++ b/docs/zh_cn/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3),你可以使用以下命令安装 lmdeploy:
```shell
-export LMDEPLOY_VERSION=0.6.3
+export LMDEPLOY_VERSION=0.6.4
export PYTHON_VERSION=38
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
diff --git a/docs/zh_cn/llm/api_server.md b/docs/zh_cn/llm/api_server.md
index d6c0c42aef..8bb91c619e 100644
--- a/docs/zh_cn/llm/api_server.md
+++ b/docs/zh_cn/llm/api_server.md
@@ -258,6 +258,89 @@ curl http://{server_ip}:{server_port}/v1/chat/interactive \
}'
```
+## 同时启动多个 api_server
+
+通过两步即可启动多机多卡服务。先用下面的代码创建一个启动脚本,然后:
+
+1. 启动代理服务 `lmdeploy serve proxy`。
+2. torchrun 启动脚本 `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **注意**: 多机多卡不要用默认 url `0.0.0.0:8000`,我们需要输入真实ip对应的地址,如:`11.25.34.55:8000`。多机情况下,因为不需要子节点间的通信,所以并不需要用户指定 torchrun 的 `--nnodes` 等参数,只要能保证每个节点执行一次单节点的 torchrun 就行。
+
+```python
+import os
+import socket
+from typing import List, Literal
+
+import fire
+
+
+def get_host_ip():
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ s.connect(('8.8.8.8', 80))
+ ip = s.getsockname()[0]
+ finally:
+ s.close()
+ return ip
+
+
+def main(model_path: str,
+ tp: int = 1,
+ proxy_url: str = 'http://0.0.0.0:8000',
+ port: int = 23333,
+ backend: Literal['turbomind', 'pytorch'] = 'turbomind'):
+ local_rank = int(os.environ.get('LOCAL_RANK', -1))
+ world_size = int(os.environ.get('WORLD_SIZE', -1))
+ local_ip = get_host_ip()
+ if isinstance(port, List):
+ assert len(port) == world_size
+ port = port[local_rank]
+ else:
+ port += local_rank * 10
+ if (world_size - local_rank) % tp == 0:
+ rank_list = ','.join([str(local_rank + i) for i in range(tp)])
+ command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\
+ f'--server-name {local_ip} --server-port {port} --tp {tp} '\
+ f'--proxy-url {proxy_url} --backend {backend}'
+ print(f'running command: {command}')
+ os.system(command)
+
+
+if __name__ == '__main__':
+ fire.Fire(main)
+```
+
+### 示例
+
+为了进一步展示如何在集群环境中使用多机多卡服务,下面提供一个在火山云的用例:
+
+```shell
+#!/bin/bash
+# 激活 conda 环境
+source /path/to/your/home/miniconda3/bin/activate /path/to/your/home/miniconda3/envs/your_env
+export HOME=/path/to/your/home
+# 获取主节点IP地址(假设 MLP_WORKER_0_HOST 是主节点的IP)
+MASTER_IP=${MLP_WORKER_0_HOST}
+# 检查是否为主节点
+if [ "${MLP_ROLE_INDEX}" -eq 0 ]; then
+ # 启动 lmdeploy serve proxy 并放入后台
+ echo "Starting lmdeploy serve proxy on master node..."
+ PROXY_PORT=8000
+ lmdeploy serve proxy --server-name ${MASTER_IP} --server-port ${PROXY_PORT} &
+else
+ # 这里我们默认调度平台同时启动了所有机器,否则要sleep一会,等待 proxy 启动成功
+ echo "Not starting lmdeploy serve proxy on worker node ${MLP_ROLE_INDEX}."
+fi
+# 启动 torchrun 并放入后台
+# 再次强调多机环境下并不需要传--nnodes 或者 --master-addr 等参数,相当于每个机器上执行一次单节点的 torchrun 即可。
+torchrun \
+--nproc_per_node=${MLP_WORKER_GPU} \
+/path/to/script.py \
+InternLM/internlm2-chat-1_8b 8 http://${MASTER_IP}:${PROXY_PORT}
+# 打印主机的IP地址
+echo "Host IP addresses:"
+hostname -I
+```
+
## 接入 WebUI
LMDeploy 提供 gradio 和 [OpenAOE](https://github.com/InternLM/OpenAOE) 两种方式,为 api_server 接入 WebUI。
diff --git a/docs/zh_cn/multi_modal/llava.md b/docs/zh_cn/multi_modal/llava.md
index c40f37308a..6538d1b861 100644
--- a/docs/zh_cn/multi_modal/llava.md
+++ b/docs/zh_cn/multi_modal/llava.md
@@ -6,11 +6,17 @@ LMDeploy 支持以下 LLaVA 系列模型,具体如下表所示:
| :----------------------------------: | :--: | :----------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind, PyTorch |
-| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind, PyTorch |
+| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
+| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
+| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
+| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型,并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。
+```{note}
+自 0.6.4 之后,PyTorch 引擎移除了对 llava 原始模型的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到
+```
+
## 安装
请按照[安装指南](../get_started/installation.md)安装 LMDeploy。
diff --git a/docs/zh_cn/multi_modal/qwen2_vl.md b/docs/zh_cn/multi_modal/qwen2_vl.md
index f62d2de74c..7cb7efe93b 100644
--- a/docs/zh_cn/multi_modal/qwen2_vl.md
+++ b/docs/zh_cn/multi_modal/qwen2_vl.md
@@ -4,7 +4,7 @@ LMDeploy 支持 Qwen-VL 系列模型,具体如下:
| Model | Size | Supported Inference Engine |
| :----------: | :----: | :------------------------: |
-| Qwen-VL-Chat | - | TurboMind, Pytorch |
+| Qwen-VL-Chat | - | TurboMind |
| Qwen2-VL | 2B, 7B | PyTorch |
本文将以[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)为例,演示使用 LMDeploy 部署 Qwen2-VL 系列模型的方法
diff --git a/docs/zh_cn/quantization/w4a16.md b/docs/zh_cn/quantization/w4a16.md
index d69a8a23d2..3cea164dd9 100644
--- a/docs/zh_cn/quantization/w4a16.md
+++ b/docs/zh_cn/quantization/w4a16.md
@@ -131,3 +131,8 @@ lmdeploy serve api_client http://0.0.0.0:23333
| ---------------- | ------- | ------- | --------- |
| Llama-2-7B-chat | 112.9 | 159.4 | 206.4 |
| Llama-2-13B-chat | N/A | 90.7 | 115.8 |
+
+## 快速问答
+
+1. 量化时出现 Out of Memory(显存不够):可以通过减小传参 `--calib-seqlen`、增大传参 `--calib-samples`,并将 `--batch-size` 设为 1 来解决。
+2. 量化时无法连接 huggingface 下载数据集:可以尝试使用镜像,`export HF_ENDPOINT=https://hf-mirror.com`。
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index d734523282..3ec3688e1b 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -10,7 +10,7 @@
| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
@@ -18,9 +18,13 @@
| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
| Mistral | 7B | LLM | Yes | Yes | Yes | No |
| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
@@ -29,7 +33,7 @@
| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
-| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes |
+| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
@@ -41,7 +45,8 @@
“-” 表示还没有验证。
```{note}
-turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine
+* turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine
+* 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等
```
## PyTorchEngine CUDA 平台
@@ -68,11 +73,13 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att
| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
-| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes |
+| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
@@ -81,7 +88,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att
| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - |
-| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes |
+| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - |
| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - |
| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
@@ -94,7 +101,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att
| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
```{note}
-* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
+* 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16。
```
## PyTorchEngine 华为昇腾平台
diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py
index ce5cbd98ff..760a82b1c9 100644
--- a/lmdeploy/archs.py
+++ b/lmdeploy/archs.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
-from typing import Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
from transformers import AutoConfig
@@ -128,7 +128,8 @@ def check_vl_llm(config: dict) -> bool:
return True
elif arch == 'MultiModalityCausalLM' and 'language_config' in config:
return True
- elif arch == 'ChatGLMModel' and 'vision_config' in config:
+ elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'
+ ] and 'vision_config' in config:
return True
elif arch in supported_archs:
return True
@@ -193,3 +194,22 @@ def get_model_arch(model_path: str):
raise RuntimeError(
f'Could not find model architecture from config: {_cfg}')
return arch, cfg
+
+
+def search_nested_config(config, key):
+ """Recursively searches for the value associated with the given key in a
+ nested configuration of a model."""
+ if isinstance(config, Dict):
+ for k, v in config.items():
+ if k == key:
+ return v
+ if isinstance(v, (Dict, List)):
+ result = search_nested_config(v, key)
+ if result is not None:
+ return result
+ elif isinstance(config, List):
+ for item in config:
+ result = search_nested_config(item, key)
+ if result is not None:
+ return result
+ return None
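A tiny usage sketch of the new helper; the nested config dict below is made up for illustration.

```python
from lmdeploy.archs import search_nested_config

cfg = {'language_config': {'hidden_size': 4096,
                           'quantization_config': {'bits': 4}}}
assert search_nested_config(cfg, 'bits') == 4            # found in a nested dict
assert search_nested_config(cfg, 'rope_theta') is None   # missing keys return None
```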
diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py
index d76d6a5f34..236e022b34 100644
--- a/lmdeploy/cli/lite.py
+++ b/lmdeploy/cli/lite.py
@@ -35,6 +35,7 @@ def add_parser_auto_awq():
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
+ ArgumentHelper.dtype(parser)
parser.add_argument(
'--device',
type=str,
@@ -71,6 +72,7 @@ def add_parser_auto_gptq():
ArgumentHelper.calib_samples(parser)
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
+ ArgumentHelper.dtype(parser)
parser.add_argument('--w-bits',
type=int,
default=4,
@@ -99,6 +101,7 @@ def add_parser_calibrate():
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
+ ArgumentHelper.dtype(parser)
@staticmethod
def add_parser_smooth_quant():
@@ -122,6 +125,7 @@ def add_parser_smooth_quant():
ArgumentHelper.calib_seqlen(parser)
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
+ ArgumentHelper.dtype(parser)
@staticmethod
def auto_awq(args):
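With `ArgumentHelper.dtype` wired into these subcommands, a hypothetical invocation forcing bfloat16 calibration could look like this; the model path and output directory are placeholders.

```shell
lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat \
    --dtype bfloat16 \
    --work-dir ./internlm2_5-7b-chat-w8a8
```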
diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index 68f9de8c15..5cf3453b7e 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -238,6 +238,7 @@ def add_parser_proxy():
help='the strategy to dispatch requests to nodes')
ArgumentHelper.api_keys(parser)
ArgumentHelper.ssl(parser)
+ ArgumentHelper.log_level(parser)
@staticmethod
def gradio(args):
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index 85784a58f5..cf7b6526ec 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -354,7 +354,7 @@ def calib_batchsize(parser):
@staticmethod
def calib_search_scale(parser):
- """Add argument batch_size to parser."""
+ """Add argument search_scale to parser."""
return parser.add_argument(
'--search-scale',
diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py
index c41b28fd6e..2c84612839 100644
--- a/lmdeploy/lite/apis/auto_awq.py
+++ b/lmdeploy/lite/apis/auto_awq.py
@@ -2,6 +2,7 @@
import os
import os.path as osp
import shutil
+from typing import Literal
import torch
from torch import nn
@@ -12,9 +13,7 @@
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.check_env import try_import_deeplink
-from .calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate
-
-NORM_TYPE_MAP = NORM_TYPE_MAP # legacy
+from .calibrate import LAYER_TYPE_MAP, calibrate
def save_vl_model(vl_model, model_path, dst_path):
@@ -56,6 +55,7 @@ def auto_awq(model: str,
search_scale: bool = False,
device: str = 'cuda',
revision: str = None,
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
download_dir: str = None):
"""Perform weight quantization using AWQ algorithm.
@@ -77,6 +77,7 @@ def auto_awq(model: str,
revision (str): The specific model version to use. It can be a
branch name, a tag name, or a commit id. If unspecified,
will use the default version.
+ dtype (str): Data type for loading model weights and calib infer.
download_dir (str): Directory to download and load the weights,
default to the default cache directory of huggingface.
"""
@@ -96,6 +97,7 @@ def auto_awq(model: str,
w_bits=w_bits,
w_group_size=w_group_size,
search_scale=search_scale,
+ dtype=dtype,
batch_size=batch_size)
layer_type = LAYER_TYPE_MAP[type(model).__name__]
diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py
index 71f7a5900c..007f831a70 100644
--- a/lmdeploy/lite/apis/calibrate.py
+++ b/lmdeploy/lite/apis/calibrate.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
-from typing import Union
+from typing import Literal, Union
import torch
from torch import nn
@@ -11,6 +11,7 @@
from lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2
from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
load_hf_from_pretrained)
+from lmdeploy.vl.model.builder import load_vl_model
LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
@@ -204,6 +205,7 @@ def calibrate(model: str,
w_bits: int = 4,
w_group_size: int = 128,
search_scale: bool = False,
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
batch_size: int = 1) -> None:
"""The main function for loading the model and performing calibration on a
given dataset.
@@ -224,6 +226,7 @@ def calibrate(model: str,
w_group_size (int): Group size for weight quantization statistics.
search_scale (bool): Whether search scale ratio. Default to False,
which means only smooth quant with 0.5 ratio will be applied.
+ dtype (str): Data type for loading model weights and calib infer.
batch_size (int): The batch size for running the calib samples.
Low GPU mem requires small batch_size. Large batch_size
reduces the calibration time while costs more VRAM.
@@ -239,20 +242,35 @@ def calibrate(model: str,
model_type, _ = get_task(model)
make_compatible_internvl_config(model)
- if model_type == 'llm':
- # Load tokenizer and configuration
- tokenizer = AutoTokenizer.from_pretrained(model,
- trust_remote_code=True)
+ # Load tokenizer and configuration
+ tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+
+ if model_type == 'llm':
model = load_hf_from_pretrained(model,
- torch_dtype=torch.float16,
+ dtype=dtype,
trust_remote_code=True)
vl_model = None
elif model_type == 'vlm':
- from lmdeploy.vl.model.builder import vl_model_with_tokenizer
- vl_model, model, tokenizer = vl_model_with_tokenizer(model_path=model)
+ vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
+ model = vl_model
+ if hasattr(vl_model, 'language_model'): # deepseek-vl, ...
+ model = vl_model.language_model
+ if hasattr(vl_model, 'llm'): # MiniCPMV, ...
+ model = vl_model.llm
+ model.config.use_cache = False
+ if dtype == 'float16':
+ model.half()
+ elif dtype == 'bfloat16':
+ assert torch.cuda.is_bf16_supported(
+ ), 'your device does not support bfloat16, please set --dtype float16' # noqa
+ model.to(torch.bfloat16)
+ elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16:
+ print('Warning: we cast model to float16 to prevent OOM. You'
+ ' may enforce it bfloat16 by `--dtype bfloat16`')
+ model.half()
+ model.eval()
- model.config.use_cache = False
model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
raise RuntimeError(
diff --git a/lmdeploy/lite/apis/gptq.py b/lmdeploy/lite/apis/gptq.py
index 12b88a52cd..eb4418a533 100644
--- a/lmdeploy/lite/apis/gptq.py
+++ b/lmdeploy/lite/apis/gptq.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
+from typing import Literal
import torch
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
from lmdeploy.lite.utils.calib_dataloader import get_calib_loaders
@@ -15,6 +16,7 @@ def auto_gptq(model: str,
calib_samples: int = 128,
calib_seqlen: int = 2048,
batch_size: int = 1,
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
revision: str = None):
"""Perform weight quantization using AWQ algorithm.
@@ -29,9 +31,7 @@ def auto_gptq(model: str,
calib_seqlen (int): The sequence length for calibration.
w_bits (int): Bit number for weight quantization.
w_group_size (int): Group size for weight quantization statistics.
- search_scale (bool): Whether search scale ratio. Default to False,
- which means only smooth quant with 0.5 ratio will be applied.
- device (str): Device type of running.
+ dtype (str): Data type for loading model weights and calib infer.
revision (str): The specific model version to use. It can be a
branch name, a tag name, or a commit id. If unspecified,
will use the default version.
@@ -83,9 +83,18 @@ def auto_gptq(model: str,
# load un-quantized model, by default,
# the model will always be loaded into CPU memory
+ hf_config = AutoConfig.from_pretrained(pretrained_model_dir,
+ revision=revision,
+ trust_remote_code=True)
+ torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)
+ if dtype == 'float16':
+ torch_dtype = torch.float16
+ elif dtype == 'bfloat16':
+ torch_dtype = torch.bfloat16
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir,
quantize_config,
revision=revision,
+ torch_dtype=torch_dtype,
trust_remote_code=True)
# quantize model, the examples should be list of dict whose keys
diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index c8df67355e..188eedbd0e 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Literal
+
import fire
import torch
from torch import nn
@@ -6,7 +9,8 @@
from lmdeploy.lite.apis.calibrate import (LAYER_TYPE_MAP, NORM_TYPE_MAP,
calibrate)
from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
- awq_layers, smooth_layers)
+ awq_layers, skipped_module,
+ smooth_layers)
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.models import QLinear, QRMSNorm
@@ -19,8 +23,8 @@ def smooth_quant(model: str,
search_scale: bool = False,
batch_size: int = 1,
w_bits: int = 8,
+ dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
device: str = 'cuda'):
-
model_path = model
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
@@ -31,6 +35,7 @@ def smooth_quant(model: str,
w_bits=w_bits,
w_group_size=-1,
search_scale=search_scale,
+ dtype=dtype,
batch_size=batch_size)
# calibrate function exports the calibration statistics
@@ -76,6 +81,8 @@ def smooth_quant(model: str,
rmsnorms = collect_target_modules(model, norm_type)
for name, linear in fcs.items():
+ if skipped_module(name):
+ continue
linear.to(device)
q_linear = QLinear.from_float(linear)
parent_name, _, child_name = name.rpartition('.')
@@ -84,6 +91,8 @@ def smooth_quant(model: str,
linear.to('cpu')
for name, norm in rmsnorms.items():
+ if skipped_module(name):
+ continue
norm.to(device)
q_norm = QRMSNorm.from_float(norm)
parent_name, _, child_name = name.rpartition('.')
diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py
index cf03a75216..3e24a13cc3 100644
--- a/lmdeploy/lite/quantization/awq.py
+++ b/lmdeploy/lite/quantization/awq.py
@@ -43,8 +43,10 @@
'MixtralDecoderLayer': {
'input_layernorm':
['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
- 'post_attention_layernorm':
- ['block_sparse_moe.experts.{i}.w1', 'block_sparse_moe.experts.{i}.w3']
+ 'post_attention_layernorm': [
+ 'block_sparse_moe.gate', 'block_sparse_moe.experts.{i}.w1',
+ 'block_sparse_moe.experts.{i}.w3'
+ ]
},
'Qwen2VLDecoderLayer': {
'input_layernorm':
@@ -120,7 +122,12 @@ def get_weight_scale(weight, q_group_size=-1):
org_shape = weight.shape
if q_group_size > 0:
weight = weight.view(-1, q_group_size)
- scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+ abs_weight = weight.abs()
+ abs_weight_amax = abs_weight.amax(dim=1, keepdim=True)
+ if abs_weight_amax.min().item() == 0:
+ print('weight.amax.min is zero, clamping weight.amax to 1e-4')
+ abs_weight_amax = abs_weight_amax.clamp(min=1e-4)
+ scale = abs_weight / abs_weight_amax
scale = scale.view(org_shape)
scale = scale.mean(0)
return scale
@@ -153,8 +160,13 @@ def smooth_ln_fcs(ln: torch.nn.Module,
concat_w = torch.cat([fc.weight for fc in fcs], dim=0)
w_scales = get_weight_scale(concat_w, group_size)
+ w_scales_pow = w_scales.pow(1 - alpha)
+ if w_scales_pow.min().item() == 0:
+ print('w_scales.pow(1 - alpha).min is zero, '
+ 'clamping w_scales.pow(1 - alpha) to 1e-4')
+ w_scales_pow = w_scales_pow.clamp(min=1e-4)
scales = (act_scales.pow(alpha) /
- w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
+ w_scales_pow).clamp(min=1e-4).to(device).to(dtype)
scales = scales / (scales[nonzero_positions].max() *
scales[nonzero_positions].min()).sqrt()
@@ -204,8 +216,13 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
concat_w = torch.cat([fc.weight for fc in fcs], dim=0)
w_scales = get_weight_scale(concat_w, group_size)
+ w_scales_pow = w_scales.pow(1 - alpha)
+ if w_scales_pow.min().item() == 0:
+ print('w_scales.pow(1 - alpha).min is zero, '
+ 'clamping w_scales.pow(1 - alpha) to 1e-4')
+ w_scales_pow = w_scales_pow.clamp(min=1e-4)
scales = (act_scales.pow(alpha) /
- w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
+ w_scales_pow).clamp(min=1e-4).to(device).to(dtype)
scales = scales / (scales.max() * scales.min()).sqrt()
# (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale
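The clamp added in both smoothing paths guards against all-zero weight channels; a standalone numeric sketch of the effect:

```python
import torch

w_scales = torch.tensor([0.0, 0.5, 1.0])      # first channel is all-zero
act_scales = torch.tensor([2.0, 2.0, 2.0])
alpha = 0.5

w_scales_pow = w_scales.pow(1 - alpha).clamp(min=1e-4)  # without the clamp, channel 0 divides by zero
scales = (act_scales.pow(alpha) / w_scales_pow).clamp(min=1e-4)
print(scales)  # finite values instead of inf for the zero channel
```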
diff --git a/lmdeploy/lite/quantization/calibration.py b/lmdeploy/lite/quantization/calibration.py
index e590f1a4eb..1df8f2c740 100644
--- a/lmdeploy/lite/quantization/calibration.py
+++ b/lmdeploy/lite/quantization/calibration.py
@@ -42,6 +42,9 @@ def __init__(self,
tokenizer (PreTrainedTokenizer): Tokenizer of the given model.
layer_type (Union[str, type]): Type of the layers to be observed.
norm_type (Union[str, type]): Norm type used in the model.
+ batch_size (int): The batch size for running the calib samples.
+ Low GPU mem requires small batch_size. Large batch_size
+ reduces the calibration time while costs more VRAM.
device (str, optional): Device where the model should run.
Defaults to 'cuda'.
"""
@@ -290,9 +293,14 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}):
org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
for ratio in range(0, n_grid):
- ratio = ratio * 1 / n_grid
- scales = (x_max.pow(ratio) /
- w_mean.pow(1 - ratio)).clamp(min=1e-4).view(-1)
+ ratio = ratio / n_grid
+ w_mean_pow = w_mean.pow(1 - ratio)
+ if w_mean_pow.min().item() == 0:
+ print('w_mean.pow(1 - ratio).min is zero, '
+ 'clamping w_mean.pow(1 - ratio) to 1e-4')
+ w_mean_pow = w_mean_pow.clamp(min=1e-4)
+ scales = (x_max.pow(ratio) / w_mean_pow).clamp(min=1e-4).view(-1)
+
scales = scales / (scales.max() * scales.min()).sqrt()
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py
index bfd306a743..ac4519371a 100644
--- a/lmdeploy/lite/utils/load.py
+++ b/lmdeploy/lite/utils/load.py
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Literal
+
import torch
from transformers import AutoConfig, AutoModelForCausalLM
@@ -7,29 +9,42 @@
def load_hf_from_pretrained(pretrained_model_name_or_path,
- dtype=torch.float16,
- **kwargs):
+ dtype: Literal['float16', 'bfloat16',
+ 'auto'], **kwargs):
- if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+ if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported():
raise RuntimeError('Your device does not supports bf16(bfloat16), '
'please change to fp16(float16)')
kwargs.pop('config', None)
hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
- torch_dtype=dtype,
trust_remote_code=True)
# HACK hard code for qwen, other configs do not have the `fp16` attribute.
- if dtype == torch.float16:
- hf_config.fp16 = True
- elif dtype == torch.bfloat16:
- hf_config.bf16 = True
+ if hasattr(hf_config, 'fp16') or hasattr(hf_config, 'bf16'):
+ if dtype == 'bfloat16':
+ hf_config.bf16 = True
+ else:
+ hf_config.fp16 = True
+
+ torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)
+ if dtype == 'bfloat16':
+ torch_dtype = torch.bfloat16
+ elif dtype == 'float16':
+ torch_dtype = torch.float16
+ elif dtype == 'auto' and torch_dtype == torch.bfloat16:
+ print('Warning: we cast model to float16 to prevent OOM. '
+ 'You may enforce it bfloat16 by `--dtype bfloat16`')
+ torch_dtype = torch.float16
with LoadNoInit():
# Load model
model = AutoModelForCausalLM.from_pretrained(
- pretrained_model_name_or_path, config=hf_config, **kwargs)
+ pretrained_model_name_or_path,
+ config=hf_config,
+ torch_dtype=torch_dtype,
+ **kwargs)
model.config.use_cache = False
return model
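A hypothetical call showing the new string-valued `dtype` argument; the model id is a placeholder.

```python
from lmdeploy.lite.utils import load_hf_from_pretrained

# 'auto' keeps the checkpoint dtype, except that bfloat16 checkpoints are
# downcast to float16 (with a warning) to reduce memory pressure.
model = load_hf_from_pretrained('internlm/internlm2_5-7b-chat',
                                dtype='auto',
                                trust_remote_code=True)
```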
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 4f04906f12..b54aa95330 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -293,8 +293,11 @@ def __post_init__(self):
assert self.device_type in [
'cuda', 'ascend', 'maca'
], (f'invalid device_type: {self.device_type}')
- if self.quant_policy > 0 and self.device_type != 'cuda':
- assert False, 'kv cache quantization only works for CUDA.'
+ if self.quant_policy > 0 and self.device_type not in [
+ 'cuda', 'ascend'
+ ]:
+ assert False, \
+ 'kv cache quantization only works for CUDA and ASCEND.'
class ResponseType(enum.Enum):
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index a4355ea131..a0b0c8e09b 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -847,7 +847,7 @@ def __init__(
- Only call one function at a time
- Put the entire function call reply on one line"
- Always add your sources when using search results to answer the user query\n\n""", # noqa
- knowledge='Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n',
+ knowledge='Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n',
meta_instruction='You are a helpful assistant.',
ipython='<|start_header_id|>ipython<|end_header_id|>\n\n',
eoi='<|eot_id|>',
@@ -1921,5 +1921,5 @@ def best_match_model(query: str) -> Optional[str]:
for name, model in MODELS.module_dict.items():
if model.match(query):
return model.match(query)
- logger.warn(f'Did not find a chat template matching {query}.')
+ logger.warning(f'Did not find a chat template matching {query}.')
return 'base'
diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py
index 92a0befbf4..f0e60d86ac 100644
--- a/lmdeploy/pytorch/backends/attention.py
+++ b/lmdeploy/pytorch/backends/attention.py
@@ -34,6 +34,7 @@ def __init__(
alibi: bool = None,
sliding_window: int = None,
logit_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> None:
if scale is None:
@@ -53,6 +54,7 @@ def __init__(
self.alibi = alibi
self.sliding_window = sliding_window
self.logit_softcapping = logit_softcapping
+ self.causal = causal
@abstractmethod
def forward(
@@ -82,6 +84,7 @@ def build(
alibi: bool = False,
sliding_window: int = None,
logical_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> AttentionImpl[T]:
"""build."""
diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py
index ef538f7a3d..c8623666dc 100644
--- a/lmdeploy/pytorch/backends/base.py
+++ b/lmdeploy/pytorch/backends/base.py
@@ -12,7 +12,8 @@
class OpType(Enum):
"""Layer type enumerate."""
- Attention = auto()
+ PagedAttention = auto()
+ FlashAttention = auto()
Linear = auto()
RotaryEmbedding = auto()
ApplyRotaryEmb = auto()
diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py
index 8261b869f0..f9227497f2 100644
--- a/lmdeploy/pytorch/backends/cuda/attention.py
+++ b/lmdeploy/pytorch/backends/cuda/attention.py
@@ -41,6 +41,7 @@ def __init__(
alibi: bool = False,
sliding_window: int = None,
logit_softcapping: float = None,
+ causal: bool = True,
**kwargs,
):
super().__init__(
@@ -52,8 +53,10 @@ def __init__(
alibi=alibi,
sliding_window=sliding_window,
logit_softcapping=logit_softcapping,
+ causal=causal,
**kwargs,
)
+ assert not (alibi and not causal)
from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd,
fill_kv_cache,
@@ -172,6 +175,7 @@ def forward(
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
+ causal=self.causal,
)
else:
self.alibi_paged_attention_fwd(
@@ -207,6 +211,7 @@ def build(
alibi: bool = False,
sliding_window: int = None,
logical_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> TritonAttentionImpl:
"""build."""
@@ -218,4 +223,5 @@ def build(
alibi=alibi,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
+ causal=causal,
**kwargs)
diff --git a/lmdeploy/pytorch/backends/cuda/flash_attention.py b/lmdeploy/pytorch/backends/cuda/flash_attention.py
new file mode 100644
index 0000000000..5d3925b744
--- /dev/null
+++ b/lmdeploy/pytorch/backends/cuda/flash_attention.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import Tensor
+
+from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl
+
+
+class TritonFlashAttentionImpl(FlashAttentionImpl):
+ """triton flash attention implementation."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ ):
+ if scale is None:
+ scale = 1.0 / (head_dim**0.5)
+
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+
+ if v_head_dim is None:
+ v_head_dim = head_dim
+
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.scale = scale
+ self.num_kv_heads = num_kv_heads
+ self.v_head_dim = v_head_dim
+ self.causal = causal
+ self.sliding_window = sliding_window
+ self.logical_softcapping = logical_softcapping
+
+ from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd
+ self.flash_attention_fwd = flash_attention_fwd
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None):
+ """forward."""
+
+ q_shape = query.shape
+ o_shape = q_shape[:-1] + (self.v_head_dim, )
+ out = query.new_empty(o_shape)
+ self.flash_attention_fwd(
+ query,
+ key,
+ value,
+ out,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_seqlen=max_q_seqlen,
+ window_size=self.sliding_window,
+ sm_scale=self.scale,
+ logit_softcapping=self.logical_softcapping,
+ causal=self.causal,
+ kv_layout='shd',
+ )
+
+ return out
+
+
+class TritonFlashAttentionBuilder(FlashAttentionBuilder):
+ """triton attention builder."""
+
+ @staticmethod
+ def build(
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ **kwargs,
+ ) -> FlashAttentionImpl:
+ """build."""
+ return TritonFlashAttentionImpl(
+ num_heads=num_heads,
+ head_dim=head_dim,
+ scale=scale,
+ num_kv_heads=num_kv_heads,
+ v_head_dim=v_head_dim,
+ causal=causal,
+ sliding_window=sliding_window,
+ logical_softcapping=logical_softcapping,
+ )
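A rough construction sketch for the new op; the sizes are illustrative. `forward` then expects flattened query/key/value in the 'shd' layout along with `q_start_loc`/`q_seqlens` and `kv_start_loc`/`kv_seqlens`, as the implementation above shows.

```python
from lmdeploy.pytorch.backends.cuda.flash_attention import \
    TritonFlashAttentionBuilder

attn_impl = TritonFlashAttentionBuilder.build(num_heads=32,
                                              head_dim=128,
                                              num_kv_heads=8,
                                              causal=True)
```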
diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py
index d796f8e19f..bfe89dc63d 100644
--- a/lmdeploy/pytorch/backends/cuda/op_backend.py
+++ b/lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -23,9 +23,12 @@ def get_name() -> str:
@classmethod
def get_layer_impl_builder(cls, layer_type: OpType):
"""get cuda layer builder."""
- if layer_type == OpType.Attention:
+ if layer_type == OpType.PagedAttention:
from .attention import TritonAttentionBuilder
return TritonAttentionBuilder
+ elif layer_type == OpType.FlashAttention:
+ from .flash_attention import TritonFlashAttentionBuilder
+ return TritonFlashAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import TritonApplyRotaryEmbBuilder
return TritonApplyRotaryEmbBuilder
@@ -125,30 +128,30 @@ def update_step_context(cls, step_context):
quant_policy=step_context.kv_quant_policy,
)
- cross_attn_metadata = None
- fill_seqlens = None
- if step_context.cross_attention_states is not None:
- fill_seqlens = torch.zeros_like(q_seqlens)
- for idx, state in enumerate(step_context.cross_attention_states):
- if state is not None:
- fill_seqlens[idx] = state.shape[-2]
+ cross_seqlens = step_context.cross_seqlens
cross_kv_seqlens = step_context.cross_kv_seqlens
- cross_kv_start_loc = None
- cross_kv_flatten_size = None
- if not step_context.is_decoding and cross_kv_seqlens is not None:
- cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens
- cross_kv_flatten_size = cross_kv_seqlens.sum().item()
- cross_attn_metadata = attn_meta_cls(
- step_context.is_decoding,
- step_context.block_offsets,
- q_start_loc=q_start_loc,
- q_seqlens=q_seqlens,
- kv_start_loc=cross_kv_start_loc,
- kv_seqlens=cross_kv_seqlens,
- kv_flatten_size=cross_kv_flatten_size,
- fill_seqlens=fill_seqlens,
- quant_policy=step_context.kv_quant_policy,
- )
+ cross_attn_metadata = None
+ if cross_seqlens is not None:
+ fill_seqlens = cross_seqlens
+ if fill_seqlens.sum().item() == 0:
+ fill_seqlens = None
+ cross_kv_start_loc = None
+ cross_kv_flatten_size = None
+ if not step_context.is_decoding and cross_kv_seqlens is not None:
+ cross_kv_start_loc = cross_kv_seqlens.cumsum(
+ 0) - cross_kv_seqlens
+ cross_kv_flatten_size = cross_kv_seqlens.sum().item()
+ cross_attn_metadata = attn_meta_cls(
+ step_context.is_decoding,
+ step_context.block_offsets,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=cross_kv_start_loc,
+ kv_seqlens=cross_kv_seqlens,
+ kv_flatten_size=cross_kv_flatten_size,
+ fill_seqlens=fill_seqlens,
+ quant_policy=step_context.kv_quant_policy,
+ )
step_context.attn_metadata = attn_metadata
step_context.cross_attn_metadata = cross_attn_metadata
diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
index f9664f13ff..e3c5dc4d5e 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
@@ -33,10 +33,17 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
dlinfer.graph.config.enable_graph_mode = True
self.patch_kernels_custom_op()
self.patch_kvcache_static_shape()
- self.model = torch.compile(self.model,
- fullgraph=True,
- dynamic=True,
- backend='atbgraph')
+ if hasattr(self.model, 'language_model'):
+ self.model.language_model = torch.compile(
+ self.model.language_model,
+ fullgraph=True,
+ dynamic=True,
+ backend='atbgraph')
+ else:
+ self.model = torch.compile(self.model,
+ fullgraph=True,
+ dynamic=True,
+ backend='atbgraph')
def check_enable_graph(self):
"""check enable graph."""
diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index b6f544510b..588558f0d5 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+import itertools
+import os
+import re
+from pathlib import Path
+from typing import Dict, Tuple
import torch
@@ -11,6 +15,71 @@
logger = get_logger('lmdeploy')
+class AscendKVQuantMeta:
+ has_set_value: bool = False
+ quant_meta: Dict = {}
+
+ @classmethod
+ def set_value(cls, device: str, dtype: torch.dtype, record_file: str,
+ total_layers: int):
+ with open(record_file, 'r') as file:
+ data = file.read()
+ scale_offset_pairs = re.findall(
+ r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data)
+ scale_offset_pairs = [(float(scale), float(offset))
+ for scale, offset in scale_offset_pairs]
+ k_scales, v_scales, kv_scales = [], [], []
+ k_zeros, v_zeros, kv_zeros = [], [], []
+ if len(scale_offset_pairs) == total_layers:
+ for scale, offset in scale_offset_pairs:
+ k_scales.append(
+ torch.tensor([scale], device=device, dtype=dtype))
+ v_scales.append(
+ torch.tensor([scale], device=device, dtype=dtype))
+ kv_scales.append(
+ torch.tensor([scale, scale], device=device, dtype=dtype))
+ k_zeros.append(
+ torch.tensor([offset], device=device, dtype=dtype))
+ v_zeros.append(
+ torch.tensor([offset], device=device, dtype=dtype))
+ kv_zeros.append(
+ torch.tensor([offset, offset], device=device, dtype=dtype))
+ elif len(scale_offset_pairs) == total_layers * 2:
+ for i in range(total_layers):
+ scale_k, offset_k = scale_offset_pairs[2 * i]
+ scale_v, offset_v = scale_offset_pairs[2 * i + 1]
+ k_scales.append(
+ torch.tensor([scale_k], device=device, dtype=dtype))
+ v_scales.append(
+ torch.tensor([scale_v], device=device, dtype=dtype))
+ kv_scales.append(
+ torch.tensor([scale_k, scale_v],
+ device=device,
+ dtype=dtype))
+ k_zeros.append(
+ torch.tensor([offset_k], device=device, dtype=dtype))
+ v_zeros.append(
+ torch.tensor([offset_v], device=device, dtype=dtype))
+ kv_zeros.append(
+ torch.tensor([offset_k, offset_v],
+ device=device,
+ dtype=dtype))
+ else:
+ raise ValueError(
+ f'num of scale_offset_pairs({len(scale_offset_pairs)}) '
+ f'must match num of total_layers({total_layers})')
+
+ cls.quant_meta.update({
+ 'k_scales': itertools.cycle(k_scales),
+ 'k_zeros': itertools.cycle(k_zeros),
+ 'v_scales': itertools.cycle(v_scales),
+ 'v_zeros': itertools.cycle(v_zeros),
+ 'kv_scales': itertools.cycle(kv_scales),
+ 'kv_zeros': itertools.cycle(kv_zeros)
+ })
+ cls.has_set_value = True
+
+
class AscendOpsBackend(DlinferOpsBackend):
"""ascend layer backend."""
enable_graph = False
@@ -164,6 +233,21 @@ def get_total_slots():
.repeat_interleave(step_context.q_seqlens, 0)
kv_seqlens = kv_seqlens_cpu
+ if not cls.enable_graph and step_context.kv_quant_policy == 8:
+ record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
+ assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'
+ path = Path(record_file)
+ is_path = path.is_absolute() or path.is_relative_to('/')
+ exists = path.exists()
+ if not (is_path and exists):
+ raise ValueError(
+ 'please specify valid ASCEND_QUANT_RECORD_FILE')
+ if not AscendKVQuantMeta.has_set_value:
+ total_layers = len(step_context.kv_caches)
+ AscendKVQuantMeta.set_value(step_context.block_offsets.device,
+ step_context.model_config.dtype,
+ record_file, total_layers)
+
attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
@@ -177,6 +261,8 @@ def get_total_slots():
is_unpaged_prefill=is_unpaged_prefill,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
+ quant_policy=step_context.kv_quant_policy,
+ quant_meta=AscendKVQuantMeta.quant_meta,
)
step_context.attn_metadata = attn_metadata
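
For context, `AscendKVQuantMeta.set_value` parses the file pointed to by `ASCEND_QUANT_RECORD_FILE` with the regex shown above, expecting one `scale ... offset ...` pair per layer (or two per layer for separate K/V values). A runnable sketch of the expected record format; the layer names and numbers are made up:

    import re

    record = ('layer_0 scale: 0.0123 offset: -3\n'
              'layer_1 scale: 0.0456 offset: 2\n')   # one scale/offset pair per layer
    pairs = re.findall(r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', record)
    assert len(pairs) == 2   # must equal total_layers (or 2 * total_layers for split K/V)
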
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 0d666c9130..6b03403c84 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
-from typing import Optional, Sequence
+from typing import Dict, Optional, Sequence
from torch import Tensor
@@ -15,6 +15,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
is_unpaged_prefill: Optional[bool] = None
max_q_seq_len: int = 1
max_kv_seq_len: int = 1
+ quant_meta: Dict = None
class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
@@ -30,8 +31,10 @@ def __init__(
alibi: bool = None,
sliding_window: int = None,
logit_softcapping: float = None,
+ causal: bool = True,
**kwargs,
):
+ assert causal
super().__init__(
num_heads,
head_size,
@@ -41,6 +44,7 @@ def __init__(
alibi,
sliding_window,
logit_softcapping,
+ causal=causal,
**kwargs,
)
@@ -74,10 +78,37 @@ def forward(
is_unpaged_prefill = attn_metadata.is_unpaged_prefill
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
+ quant_bits = attn_metadata.quant_policy
+ if attn_metadata.quant_meta is not None:
+ k_scales_zeros = [
+ next(attn_metadata.quant_meta['k_scales']),
+ next(attn_metadata.quant_meta['k_zeros'])
+ ] if 'k_scales' in attn_metadata.quant_meta else []
+ v_scales_zeros = [
+ next(attn_metadata.quant_meta['v_scales']),
+ next(attn_metadata.quant_meta['v_zeros'])
+ ] if 'v_scales' in attn_metadata.quant_meta else []
+ kv_scales = next(
+ attn_metadata.quant_meta['kv_scales']
+ ) if 'kv_scales' in attn_metadata.quant_meta else None
+ kv_zeros = next(
+ attn_metadata.quant_meta['kv_zeros']
+ ) if 'kv_zeros' in attn_metadata.quant_meta else None
+ else:
+ k_scales_zeros = []
+ v_scales_zeros = []
+ kv_scales = None
+ kv_zeros = None
# fill kv cache
- k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
- kv_start_indices)
+ k_cache, v_cache = self.fill_kv_cache(key,
+ value,
+ k_cache,
+ v_cache,
+ kv_start_indices,
+ k_scales_zeros=k_scales_zeros,
+ v_scales_zeros=v_scales_zeros,
+ quant_bits=quant_bits)
if inplace:
attn_output = query[..., :self.v_head_size]
@@ -103,6 +134,9 @@ def forward(
block_size=block_size,
attn_mask=attn_mask,
is_unpaged_prefill=is_unpaged_prefill,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
return attn_output
@@ -121,6 +155,7 @@ def build(
alibi_scale: float = None,
sliding_window: int = None,
logical_softcapping: float = None,
+ causal: bool = True,
**kwargs,
) -> DlinferAttentionImpl:
"""build."""
@@ -132,4 +167,5 @@ def build(
alibi_scale=alibi_scale,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
+ causal=causal,
**kwargs)
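
The `quant_meta` entries are wrapped in `itertools.cycle`, so each layer's `forward` simply calls `next()` and the iteration wraps around again on the next decoding step. A tiny standalone illustration:

    import itertools
    import torch

    # two layers' worth of per-layer K scales, cycled so consecutive next()
    # calls walk layer 0, layer 1, layer 0, layer 1, ...
    k_scales = [torch.tensor([0.10]), torch.tensor([0.20])]
    quant_meta = {'k_scales': itertools.cycle(k_scales)}

    for call in range(4):              # 2 layers x 2 forward passes
        print(call, next(quant_meta['k_scales']))
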
diff --git a/lmdeploy/pytorch/backends/dlinfer/flash_attention.py b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py
new file mode 100644
index 0000000000..d0d9ddbb26
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import Tensor
+
+from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl
+
+
+class DlinferFlashAttentionImpl(FlashAttentionImpl):
+ """dlinfer flash attention implementation."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ ):
+ if scale is None:
+ scale = 1.0 / (head_dim**0.5)
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+ if v_head_dim is None:
+ v_head_dim = head_dim
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.scale = scale
+ self.num_kv_heads = num_kv_heads
+ self.v_head_dim = v_head_dim
+ self.causal = causal
+ self.sliding_window = sliding_window
+ self.logical_softcapping = logical_softcapping
+ from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd
+ self.flash_attention_fwd = flash_attention_fwd
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None):
+ """forward."""
+ q_shape = query.shape
+ o_shape = q_shape[:-1] + (self.v_head_dim, )
+ out = query.new_empty(o_shape)
+ self.flash_attention_fwd(
+ query,
+ key,
+ value,
+ out,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_q_seqlen=max_q_seqlen,
+ window_size=self.sliding_window,
+ sm_scale=self.scale,
+ logit_softcapping=self.logical_softcapping,
+ causal=self.causal,
+ )
+ return out
+
+
+class DlinferFlashAttentionBuilder(FlashAttentionBuilder):
+ """dlinfer attention builder."""
+
+ @staticmethod
+ def build(
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ **kwargs,
+ ) -> FlashAttentionImpl:
+ """build."""
+ return DlinferFlashAttentionImpl(
+ num_heads=num_heads,
+ head_dim=head_dim,
+ scale=scale,
+ num_kv_heads=num_kv_heads,
+ v_head_dim=v_head_dim,
+ causal=causal,
+ sliding_window=sliding_window,
+ logical_softcapping=logical_softcapping,
+ )
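
A hedged usage sketch of the new builder, assuming a dlinfer-enabled device and the flattened `[total_tokens, num_heads, head_dim]` layout used by the PyTorch engine; the shapes below are illustrative:

    import torch
    from lmdeploy.pytorch.backends.dlinfer.flash_attention import (
        DlinferFlashAttentionBuilder)

    attn = DlinferFlashAttentionBuilder.build(num_heads=8, head_dim=64)
    q = torch.randn(8, 8, 64)            # two sequences of length 3 and 5, flattened
    k = torch.randn(8, 8, 64)
    v = torch.randn(8, 8, 64)
    start_loc = torch.tensor([0, 3])
    seqlens = torch.tensor([3, 5])
    out = attn.forward(q, k, v,
                       q_start_loc=start_loc, q_seqlens=seqlens,
                       kv_start_loc=start_loc, kv_seqlens=seqlens,
                       max_q_seqlen=5)   # -> [8, 8, 64]
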
diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py
index 6ada730fbe..ff986c5765 100644
--- a/lmdeploy/pytorch/backends/dlinfer/moe.py
+++ b/lmdeploy/pytorch/backends/dlinfer/moe.py
@@ -47,8 +47,8 @@ def forward(self,
down_weights: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
- return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights,
- gate_up_weights, down_weights)
+ return fused_moe(hidden_states, gate_up_weights, down_weights,
+ topk_weights, topk_ids, self.top_k, self.renormalize)
class DlinferFusedMoEBuilder(FusedMoEBuilder):
diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
index 52a8830595..a0f04f34b1 100644
--- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -22,9 +22,12 @@ def get_name() -> str:
@classmethod
def get_layer_impl_builder(cls, layer_type: OpType):
"""get dlinfer layer builder."""
- if layer_type == OpType.Attention:
+ if layer_type == OpType.PagedAttention:
from .attention import DlinferAttentionBuilder
return DlinferAttentionBuilder
+ elif layer_type == OpType.FlashAttention:
+ from .flash_attention import DlinferFlashAttentionBuilder
+ return DlinferFlashAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
return DlinferApplyRotaryEmbBuilder
diff --git a/lmdeploy/pytorch/backends/flash_attention.py b/lmdeploy/pytorch/backends/flash_attention.py
new file mode 100644
index 0000000000..bed3af8d68
--- /dev/null
+++ b/lmdeploy/pytorch/backends/flash_attention.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+
+from torch import Tensor
+
+
+class FlashAttentionImpl(ABC):
+ """FlashAttention implementation."""
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None):
+ """forward."""
+ raise NotImplementedError
+
+
+class FlashAttentionBuilder(ABC):
+ """FlashAttention implementation builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logical_softcapping: float = None,
+ **kwargs,
+ ) -> FlashAttentionImpl:
+ """build."""
+ raise NotImplementedError
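
Backends plug in by subclassing this pair, mirroring the existing attention builders. A skeletal sketch; the class names are made up and the kernel call is left abstract:

    from torch import Tensor
    from lmdeploy.pytorch.backends.flash_attention import (
        FlashAttentionBuilder, FlashAttentionImpl)

    class MyFlashAttentionImpl(FlashAttentionImpl):
        def forward(self, query: Tensor, key: Tensor, value: Tensor,
                    q_start_loc: Tensor, q_seqlens: Tensor,
                    kv_start_loc: Tensor, kv_seqlens: Tensor,
                    max_q_seqlen: int = None):
            raise NotImplementedError('dispatch to the backend kernel here')

    class MyFlashAttentionBuilder(FlashAttentionBuilder):
        @staticmethod
        def build(num_heads, head_dim, scale=None, num_kv_heads=None,
                  v_head_dim=None, causal=True, sliding_window=None,
                  logical_softcapping=None, **kwargs) -> FlashAttentionImpl:
            return MyFlashAttentionImpl()
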
diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py
index 9ab66b26a2..9347995e0b 100644
--- a/lmdeploy/pytorch/backends/graph_runner.py
+++ b/lmdeploy/pytorch/backends/graph_runner.py
@@ -46,3 +46,26 @@ def prepare_inputs_for_generation(
inputs_embeds,
context,
)
+
+ def update_model_metas(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: torch.Tensor = None,
+ context: StepContext = None,
+ ):
+ """prepare inputs."""
+ if hasattr(self.model, 'update_model_metas'):
+ return self.model.update_model_metas(
+ past_key_values,
+ inputs_embeds,
+ context,
+ )
+
+ return None
+
+ def get_input_processor(self):
+ """get input processor."""
+ if hasattr(self.model, 'get_input_processor'):
+ return self.model.get_input_processor()
+ else:
+ return None
diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 7d72438224..bc95a32be6 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -1,277 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from logging import Logger
-from typing import List
-
-from lmdeploy.utils import get_logger
-
-
-def _handle_exception(e: Exception,
- mod_name: str,
- logger: Logger,
- message: str = None):
- red_color = '\033[31m'
- reset_color = '\033[0m'
- if message is None:
- message = 'Please ensure it has been installed correctly.'
- logger.debug('Exception', exc_info=1)
- logger.error(f'{type(e).__name__}: {e}')
- logger.error(f'{red_color}'
- f'<{mod_name}> test failed!\n'
- f'{message}'
- f'{reset_color}')
- exit(1)
+from .base import BaseChecker # noqa: F401
def check_env_deeplink(device_type: str):
"""check Deeplink environment."""
- try_import_deeplink(device_type)
+ from .deeplink import DeeplinkChecker
+ checker = DeeplinkChecker(device_type)
+ checker.handle()
def try_import_deeplink(device_type: str):
- """import dlinfer if specific device_type is set."""
- deeplink_device_type_list = [
- 'ascend',
- 'npu',
- 'maca',
- ]
- if device_type in deeplink_device_type_list:
- logger = get_logger('lmdeploy')
- try:
- import dlinfer.framework.lmdeploy_ext # noqa: F401
- except Exception as e:
- _handle_exception(e, 'PyTorch', logger)
-
-
-def check_env_torch():
- """check PyTorch environment."""
- logger = get_logger('lmdeploy')
-
- try:
- logger.debug('Checking environment.')
- import torch
-
- a = torch.tensor([1, 2], device='cuda')
- b = a.new_tensor([3, 4], device='cuda')
- c = a + b
- torch.testing.assert_close(c, a.new_tensor([4, 6]))
- except Exception as e:
- _handle_exception(e, 'PyTorch', logger)
-
-
-MAX_TRITON_VERSION = '3.0.0'
-
-
-def check_env_triton(device: str):
- """check OpenAI Triton environment."""
- from packaging import version
- logger = get_logger('lmdeploy')
-
- msg = (
- 'Please ensure that your device is functioning properly with <Triton>.\n' # noqa: E501
- 'You can verify your environment by running '
- '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
- try:
- logger.debug('Checking environment.')
- import torch
- import triton
- triton_version = version.parse(triton.__version__)
- if triton_version > version.parse(MAX_TRITON_VERSION):
- logger.warning(
- f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')
-
- from .triton_custom_add import custom_add
- a = torch.tensor([1, 2], device='cuda')
- b = a.new_tensor([3, 4], device='cuda')
- c = custom_add(a, b)
- torch.testing.assert_close(c, a + b)
- except RuntimeError as e:
- ptxas_error = 'device kernel image is invalid'
- if len(e.args) > 0 and ptxas_error in e.args[0]:
- msg = (
- 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501
- 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501
- ' or reinstall the driver.')
- _handle_exception(e, 'Triton', logger, msg)
- except Exception as e:
- _handle_exception(e, 'Triton', logger, msg)
-
- if device == 'cuda':
- device_cap = torch.cuda.get_device_capability()
- TRITON_VER_231 = version.parse('2.3.1')
-
- if device_cap[0] <= 7:
- if triton_version <= TRITON_VER_231:
- err = RuntimeError(
- 'Attention triton kernel does not fully support '
- 'triton<3.0.0 on device with capability<8. '
- 'Please upgrade your triton version.')
- _handle_exception(err, 'Triton', logger)
-
-
-def check_env(device_type: str):
- """check all environment."""
- logger = get_logger('lmdeploy')
- logger.info('Checking environment for PyTorch Engine.')
+ """check Deeplink environment."""
check_env_deeplink(device_type)
- check_env_torch()
- if device_type == 'cuda':
- check_env_triton('cuda')
-
-
-MIN_TRANSFORMERS_VERSION = '4.33.0'
-MAX_TRANSFORMERS_VERSION = '4.44.1'
-
-
-def check_awq(hf_config, device_type):
- """check awq support."""
- logger = get_logger('lmdeploy')
- if device_type == 'cuda':
- quantization_config = getattr(hf_config, 'quantization_config', dict())
- quant_method = quantization_config.get('quant_method', None)
- if quant_method != 'awq':
- return
- try:
- import awq # noqa
- except Exception as e:
- _handle_exception(e, 'autoawq', logger)
-
- try:
- import awq_ext # noqa
- except Exception:
- logger.debug('Exception:', exc_info=1)
- logger.warning('Failed to import `awq_ext`. '
- 'Try reinstall it from source: '
- 'https://github.com/casper-hansen/AutoAWQ_kernels')
-
-
-def check_transformers_version(model_path: str,
- trust_remote_code: bool = True,
- dtype: str = 'auto',
- device_type: str = 'cuda'):
- """check transformers version."""
- from packaging import version
- logger = get_logger('lmdeploy')
-
- def __check_transformers_version():
- """check transformers version."""
- logger.debug('Checking version.')
- trans_version = None
- try:
- import transformers
- trans_version = version.parse(transformers.__version__)
- min_version = version.parse(MIN_TRANSFORMERS_VERSION)
- max_version = version.parse(MAX_TRANSFORMERS_VERSION)
- if trans_version < min_version or trans_version > max_version:
- logger.warning('LMDeploy requires transformers version: '
- f'[{MIN_TRANSFORMERS_VERSION} ~ '
- f'{MAX_TRANSFORMERS_VERSION}], '
- 'but found version: '
- f'{transformers.__version__}')
- except Exception as e:
- _handle_exception(e, 'transformers', logger)
- return transformers, trans_version
-
- def __check_config(trans_version):
- """check config."""
- logger.debug('Checking AutoConfig.from_pretrained.')
- try:
- from transformers import AutoConfig
- config = AutoConfig.from_pretrained(
- model_path, trust_remote_code=trust_remote_code)
- except Exception as e:
- message = (
- f'Load model config with transformers=={trans_version}'
- ' failed. '
- 'Please make sure model can be loaded with transformers API.')
- _handle_exception(e, 'transformers', logger, message=message)
- return config
-
- def __check_model_transformers_version(config, trans_version):
- """check model transformers version."""
- logger.debug('Checking required transformers version.')
- try:
- model_trans_version = getattr(config, 'transformers_version', None)
- if model_trans_version is not None:
- model_trans_version = version.parse(model_trans_version)
- assert trans_version >= model_trans_version, \
- 'Version mismatch.'
- except Exception as e:
- message = (f'model `{model_path}` requires '
- f'transformers version {model_trans_version} '
- f'but transformers {trans_version} is installed.')
- _handle_exception(e, 'transformers', logger, message=message)
-
- def __check_model_dtype_support(config, device_type):
- """Checking model dtype support."""
- logger.debug('Checking dtype support.')
-
- import torch
-
- from lmdeploy.pytorch.config import ModelConfig
- from lmdeploy.utils import is_bf16_supported
-
- try:
- model_config = ModelConfig.from_hf_config(config,
- model_path=model_path,
- dtype=dtype)
- if model_config.dtype == torch.bfloat16:
- assert is_bf16_supported(device_type), (
- 'bf16 is not supported on your device')
- except AssertionError as e:
- message = (
- f'Your device does not support `{model_config.dtype}`. '
- 'You can set `dtype` to float16 in PyTorchEngineConfig or '
- '`--dtype float16` to api_server.\n'
- 'Note that this might have negative effect!')
- _handle_exception(e, 'Model', logger, message=message)
- except Exception as e:
- message = (f'Checking failed with error {e}',
- 'Please send issue to LMDeploy with error logs.')
- _handle_exception(e, 'Model', logger, message=message)
-
- return model_config
-
- _, trans_version = __check_transformers_version()
- config = __check_config(trans_version)
- __check_model_transformers_version(config, trans_version)
- __check_model_dtype_support(config, device_type)
- check_awq(config, device_type)
-
-
-def check_model(model_path: str,
- trust_remote_code: bool = True,
- dtype: str = 'auto',
- device_type: str = 'cuda'):
- """check model requirements."""
- logger = get_logger('lmdeploy')
- logger.info('Checking model.')
- check_transformers_version(model_path, trust_remote_code, dtype,
- device_type)
-
-
-def check_adapter(path: str):
- """check adapter."""
- logger = get_logger('lmdeploy')
- logger.debug(f'Checking : {path}.')
-
- try:
- from peft import PeftConfig
- PeftConfig.from_pretrained(path)
- except Exception as e:
- message = ('Please make sure the adapter can be loaded with '
- '`peft.PeftConfig.from_pretrained`\n')
- err_msg = '' if len(e.args) == 0 else e.args[0]
- if 'got an unexpected keyword argument' in err_msg:
- message += ('Or try remove all unexpected keywords '
- 'in `adapter_config.json`.')
- _handle_exception(e, 'Model', logger, message=message)
-
-
-def check_adapters(adapter_paths: List[str]):
- """check adapters."""
- if len(adapter_paths) <= 0:
- return
- logger = get_logger('lmdeploy')
- logger.info('Checking adapters.')
- for path in adapter_paths:
- check_adapter(path)
diff --git a/lmdeploy/pytorch/check_env/adapter.py b/lmdeploy/pytorch/check_env/adapter.py
new file mode 100644
index 0000000000..bcaf5fd0e3
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/adapter.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+
+class AdapterChecker(BaseChecker):
+ """check adapter is available."""
+
+ def __init__(self, adapter_path: str, logger=None):
+ super().__init__(logger)
+ self.adapter_path = adapter_path
+
+ def check(self):
+ """check."""
+ path = self.adapter_path
+
+ try:
+ import peft # noqa: F401
+ except Exception as e:
+ self.log_and_exit(e, 'Adapter', message='Failed to import peft.')
+
+ try:
+ from peft import PeftConfig
+ PeftConfig.from_pretrained(path)
+ except Exception as e:
+ message = ('Please make sure the adapter can be loaded with '
+ '`peft.PeftConfig.from_pretrained`\n')
+ err_msg = '' if len(e.args) == 0 else e.args[0]
+ if 'got an unexpected keyword argument' in err_msg:
+ message += ('Or try removing all unexpected keywords '
+ 'in `adapter_config.json`.')
+ self.log_and_exit(e, 'Adapter', message=message)
diff --git a/lmdeploy/pytorch/check_env/base.py b/lmdeploy/pytorch/check_env/base.py
new file mode 100644
index 0000000000..ed5e5a600f
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/base.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from logging import Logger
+from typing import List
+
+from lmdeploy.utils import get_logger
+
+RED_COLOR = '\033[31m'
+RESET_COLOR = '\033[0m'
+
+
+def _red_text(text: str):
+ """red text."""
+ return f'{RED_COLOR}{text}{RESET_COLOR}'
+
+
+class BaseChecker:
+ """base checker."""
+
+ def __init__(self, logger: Logger = None):
+ if logger is None:
+ logger = get_logger('lmdeploy')
+ self.logger = logger
+ self._is_passed = False
+ self._required_checker: List[BaseChecker] = list()
+
+ def get_logger(self):
+ """get logger."""
+ return self.logger
+
+ def register_required_checker(self, checker: 'BaseChecker'):
+ """register_required."""
+ self._required_checker.append(checker)
+
+ def handle(self):
+ """handle check."""
+ is_passed = getattr(self, '_is_passed', False)
+ if not is_passed:
+ checker_name = type(self).__name__
+ self.logger.debug(f'Checking <{checker_name}>:')
+ for checker in self._required_checker:
+ checker.handle()
+ self.check()
+ self._is_passed = True
+
+ def log_and_exit(self,
+ e: Exception = None,
+ mod_name: str = None,
+ message: str = None):
+ logger = self.logger
+ if mod_name is None:
+ mod_name = type(self).__name__
+ if message is None:
+ message = 'Please check your environment.'
+ logger.debug('Exception', exc_info=1)
+ if e is not None:
+ logger.error(f'{type(e).__name__}: {e}')
+ logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}')
+ exit(1)
+
+ def check(self):
+ """check."""
+ raise NotImplementedError('check not implemented.')
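
`handle()` first handles every registered prerequisite checker, then runs `check()`, while `log_and_exit` prints a red error and terminates the process. A minimal sketch of composing checkers; the two checker classes below are illustrative and not from the LMDeploy codebase:

    from lmdeploy.pytorch.check_env.base import BaseChecker

    class DeviceChecker(BaseChecker):        # hypothetical prerequisite
        def check(self):
            pass                             # probe the device here

    class KernelChecker(BaseChecker):        # hypothetical dependent check
        def check(self):
            pass                             # import the kernel package here

    checker = KernelChecker()
    checker.register_required_checker(DeviceChecker())
    checker.handle()                         # DeviceChecker runs first, then KernelChecker
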
diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py
new file mode 100644
index 0000000000..74ab5a7b87
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/deeplink.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+deeplink_device_type_list = [
+ 'ascend',
+ 'npu',
+ 'maca',
+]
+
+
+class DeeplinkChecker(BaseChecker):
+ """check pytorch is available."""
+
+ def __init__(self, device_type: str, logger=None) -> None:
+ super().__init__(logger=logger)
+ self.device_type = device_type
+
+ def check(self):
+ """check."""
+ device_type = self.device_type
+ if device_type in deeplink_device_type_list:
+ try:
+ import dlinfer.framework.lmdeploy_ext # noqa: F401
+ except Exception as e:
+ self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.')
diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py
new file mode 100644
index 0000000000..4b721e50e2
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/model.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+
+class ModelChecker(BaseChecker):
+ """check model is available."""
+
+ def __init__(self,
+ model_path: str,
+ trust_remote_code: bool,
+ dtype: str,
+ device_type: str,
+ logger=None) -> None:
+ super().__init__(logger=logger)
+ self.model_path = model_path
+ self.trust_remote_code = trust_remote_code
+ self.device_type = device_type
+ self.dtype = dtype
+
+ def check_config(self, trans_version):
+ """check config."""
+ model_path = self.model_path
+ trust_remote_code = self.trust_remote_code
+ try:
+ from transformers import AutoConfig
+ config = AutoConfig.from_pretrained(
+ model_path, trust_remote_code=trust_remote_code)
+ except Exception as e:
+ message = (
+ f'Load model config with transformers=={trans_version}'
+ ' failed. '
+ 'Please make sure model can be loaded with transformers API.')
+ self.log_and_exit(e, 'transformers', message=message)
+ return config
+
+ def check_trans_version(self, config, trans_version):
+ """check transformers version."""
+ model_path = self.model_path
+ try:
+ model_trans_version = getattr(config, 'transformers_version', None)
+ if model_trans_version is not None:
+ model_trans_version = version.parse(model_trans_version)
+ assert trans_version >= model_trans_version, (
+ 'Version mismatch.')
+ except Exception as e:
+ message = (f'model `{model_path}` requires '
+ f'transformers version {model_trans_version} '
+ f'but transformers {trans_version} is installed.')
+ self.log_and_exit(e, 'transformers', message=message)
+
+ def check_dtype(self, config):
+ """check dtype."""
+ logger = self.get_logger()
+ model_path = self.model_path
+ device_type = self.device_type
+ dtype = self.dtype
+ try:
+ import torch
+
+ from lmdeploy.pytorch.config import ModelConfig
+ from lmdeploy.utils import is_bf16_supported
+ model_config = ModelConfig.from_hf_config(config,
+ model_path=model_path,
+ dtype=dtype)
+ if model_config.dtype == torch.bfloat16:
+ if not is_bf16_supported(device_type):
+ logger.warning('Device does not support bfloat16.')
+ except Exception as e:
+ message = (f'Checking failed with error {e}. '
+ 'Please send an issue to LMDeploy with the error logs.')
+ self.log_and_exit(e, 'Model', message=message)
+
+ def check_awq(self, config):
+ """check awq."""
+ logger = self.get_logger()
+ device_type = self.device_type
+ if device_type != 'cuda':
+ return
+
+ quantization_config = getattr(config, 'quantization_config', dict())
+ quant_method = quantization_config.get('quant_method', None)
+ if quant_method != 'awq':
+ return
+ try:
+ import awq # noqa
+ except Exception as e:
+ self.log_and_exit(e, 'autoawq', message='autoawq is not installed correctly.')
+
+ try:
+ import awq_ext # noqa
+ except Exception as e:
+ logger.debug('Exception:', exc_info=1)
+ self.log_and_exit(
+ e,
+ 'awq_ext',
+ message='Failed to import `awq_ext`. '
+ 'Try reinstall it from source: '
+ 'https://github.com/casper-hansen/AutoAWQ_kernels')
+
+ def check(self):
+ """check."""
+ import transformers
+ trans_version = version.parse(transformers.__version__)
+
+ # config
+ config = self.check_config(trans_version)
+
+ # transformers version
+ self.check_trans_version(config, trans_version)
+
+ # dtype check
+ self.check_dtype(config)
+
+ # awq
+ self.check_awq(config)
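
A hedged usage sketch of `ModelChecker`; the model path is a placeholder, and the checker exits with a red error message if any step fails:

    from lmdeploy.pytorch.check_env.model import ModelChecker
    from lmdeploy.pytorch.check_env.transformers import TransformersChecker

    checker = ModelChecker(model_path='/path/to/hf_model',   # placeholder path
                           trust_remote_code=True,
                           dtype='auto',
                           device_type='cuda')
    checker.register_required_checker(TransformersChecker())
    checker.handle()
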
diff --git a/lmdeploy/pytorch/check_env/torch.py b/lmdeploy/pytorch/check_env/torch.py
new file mode 100644
index 0000000000..14b24e04a0
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/torch.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+
+class TorchChecker(BaseChecker):
+ """check pytorch is available."""
+
+ def __init__(self, device: str = 'cuda', logger=None) -> None:
+ super().__init__(logger=logger)
+ self.device = device
+
+ def check(self):
+ """check."""
+ try:
+ import torch
+ a = torch.tensor([1, 2], device=self.device)
+ b = a.new_tensor([3, 4], device=self.device)
+ c = a + b
+ torch.testing.assert_close(c, a.new_tensor([4, 6]))
+ except Exception as e:
+ self.log_and_exit(e, 'PyTorch', 'PyTorch is not available.')
diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py
new file mode 100644
index 0000000000..9d97cd6dca
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/transformers.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+MIN_TRANSFORMERS_VERSION = '4.33.0'
+MAX_TRANSFORMERS_VERSION = '4.46.1'
+
+
+class TransformersChecker(BaseChecker):
+ """check transformers is available."""
+
+ def check(self):
+ """check."""
+ import transformers
+ logger = self.get_logger()
+ try:
+ trans_version = version.parse(transformers.__version__)
+ min_version = version.parse(MIN_TRANSFORMERS_VERSION)
+ max_version = version.parse(MAX_TRANSFORMERS_VERSION)
+ if trans_version < min_version or trans_version > max_version:
+ logger.warning('LMDeploy requires transformers version: '
+ f'[{MIN_TRANSFORMERS_VERSION} ~ '
+ f'{MAX_TRANSFORMERS_VERSION}], '
+ 'but found version: '
+ f'{transformers.__version__}')
+ except Exception as e:
+ self.log_and_exit(e, 'transformers',
+ 'transformers is not available.')
diff --git a/lmdeploy/pytorch/check_env/triton.py b/lmdeploy/pytorch/check_env/triton.py
new file mode 100644
index 0000000000..4cc58c5492
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/triton.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+MAX_TRITON_VERSION = '3.1.0'
+MIN_TRITON_VERSION = '3.0.0'
+
+
+class TritonChecker(BaseChecker):
+ """check triton is available."""
+
+ def check_version(self):
+ """check version."""
+ logger = self.get_logger()
+
+ # version check
+ import triton
+ max_version = version.parse(MAX_TRITON_VERSION)
+ min_version = version.parse(MIN_TRITON_VERSION)
+ triton_version = version.parse(triton.__version__)
+
+ if triton_version > max_version:
+ logger.warning('PytorchEngine has not been tested on '
+ f'triton>{MAX_TRITON_VERSION}.')
+ if triton_version < min_version:
+ msg = (f'triton>={MIN_TRITON_VERSION} is required. '
+ f'Found triton=={triton_version}')
+ self.log_and_exit(mod_name='Triton', message=msg)
+
+ def check(self):
+ """check."""
+ logger = self.get_logger()
+
+ msg = (
+ 'Please ensure that your device is functioning properly with <Triton>.\n' # noqa: E501
+ 'You can verify your environment by running '
+ '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
+ try:
+ logger.debug('Checking environment.')
+ import torch
+
+ from .triton_custom_add import custom_add
+ a = torch.tensor([1, 2], device='cuda')
+ b = a.new_tensor([3, 4], device='cuda')
+ c = custom_add(a, b)
+ torch.testing.assert_close(c, a + b)
+ except RuntimeError as e:
+ ptxas_error = 'device kernel image is invalid'
+ if len(e.args) > 0 and ptxas_error in e.args[0]:
+ msg = (
+ 'This error might be caused by a mismatch between the NVIDIA driver and the nvcc compiler. \n' # noqa: E501
+ 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501
+ ' or reinstall the driver.')
+ self.log_and_exit(e, 'Triton', msg)
+ except Exception as e:
+ self.log_and_exit(e, 'Triton', msg)
+
+ # version check
+ self.check_version()
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index c350f4b4cf..7783afd970 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -26,6 +26,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
return config
torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
+ # deal with case when torch_dtype is not string but torch.dtype
+ if isinstance(torch_dtype, torch.dtype):
+ torch_dtype = str(torch_dtype).split('.')[1]
+
if torch_dtype is None:
_dtype = 'float16' if dtype == 'auto' else dtype
logger.warning('Model config does not have `torch_dtype`,'
@@ -37,8 +41,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
# change to user specified data type if it is not 'auto'
if dtype == 'auto':
torch_dtype = torch_dtype if torch_dtype in [
- torch.float16, torch.bfloat16
- ] else torch.float16
+ 'float16', 'bfloat16'
+ ] else 'float16'
else:
torch_dtype = dtype
config.dtype = eval(f'torch.{torch_dtype}')
@@ -77,6 +81,7 @@ class CacheConfig:
max_prefill_token_num: int = 4096
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
+ device_type: str = 'cuda'
def __post_init__(self):
"""post init."""
@@ -103,7 +108,6 @@ class ModelConfig:
v_head_dim: int = None
sliding_window: int = -1
dtype: torch.dtype = torch.float16
- multi_query_attention: bool = False
vocab_size: int = 40000
hf_config: Any = None
cogvlm_style: bool = False
@@ -117,7 +121,8 @@ def get_head_size(self):
def from_pretrained(cls,
pretrained_model_name_or_path: str,
trust_remote_code: bool = True,
- dtype: str = 'auto'):
+ dtype: str = 'auto',
+ tp: int = 1):
"""Instantiate one of the configuration classes of the library from a
pretrained model configuration.
@@ -137,17 +142,21 @@ def from_pretrained(cls,
pretrained_model_name_or_path)
return cls.from_hf_config(hf_config,
pretrained_model_name_or_path,
- dtype=dtype)
+ dtype=dtype,
+ tp=tp)
@classmethod
def from_hf_config(cls,
hf_config: Any,
model_path: str = None,
- dtype: str = 'auto'):
+ dtype: str = 'auto',
+ tp: int = 1):
"""from huggingface config."""
from lmdeploy.pytorch.configurations import AutoModelConfigBuilder
- model_config = AutoModelConfigBuilder.build(hf_config, model_path)
+ model_config = AutoModelConfigBuilder.build(hf_config,
+ model_path,
+ tp=tp)
if model_config.k_head_dim is None:
assert model_config.head_dim is not None
@@ -156,6 +165,13 @@ def from_hf_config(cls,
assert model_config.head_dim is not None
model_config.v_head_dim = model_config.head_dim
+ # check for tp
+ assert model_config.num_attention_heads % tp == 0
+ if model_config.num_key_value_heads >= tp:
+ assert model_config.num_key_value_heads % tp == 0
+ else:
+ assert tp % model_config.num_key_value_heads == 0
+
# should after setting `hf_config` and `model_arch` attributes
model_config = _update_torch_dtype(model_config, dtype)
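
Two small notes on the config changes above, with a runnable sketch: `torch_dtype` read from `hf_config` may already be a `torch.dtype`, which the new branch normalizes to a string, and the tp checks require attention heads to shard evenly while kv heads may instead be replicated:

    import torch

    # normalize a torch.dtype coming from hf_config.torch_dtype to a string
    torch_dtype = torch.bfloat16
    if isinstance(torch_dtype, torch.dtype):
        torch_dtype = str(torch_dtype).split('.')[1]
    assert torch_dtype == 'bfloat16'

    # tp sharding rule mirrored from ModelConfig.from_hf_config
    num_attention_heads, num_key_value_heads, tp = 32, 2, 8
    assert num_attention_heads % tp == 0       # 4 query heads per rank
    assert (num_key_value_heads % tp == 0) or (tp % num_key_value_heads == 0)
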
diff --git a/lmdeploy/pytorch/configurations/builder.py b/lmdeploy/pytorch/configurations/builder.py
index 89bf51ca46..bafa78ba02 100644
--- a/lmdeploy/pytorch/configurations/builder.py
+++ b/lmdeploy/pytorch/configurations/builder.py
@@ -27,7 +27,7 @@ def condition(cls, hf_config):
f'`condition` of {cls.__name__} not implemented.')
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from .default import DefaultModelConfigBuilder
@@ -46,8 +46,21 @@ def build(cls, hf_config, model_path: str = None):
logger.debug(f'build model config with {valid_builder.__name__}')
- cfg = valid_builder.build(hf_config, model_path)
+ cfg = valid_builder.build(hf_config, model_path, **kwargs)
if cfg.hf_config is None:
cfg.hf_config = hf_config
return cfg
+
+ @classmethod
+ def update_num_kv_heads(cls, hf_config, tp, num_key_value_heads):
+ """update num kv heads."""
+ # update num_kv_heads for tp mode
+ if tp > 1 and tp > num_key_value_heads:
+ assert tp % num_key_value_heads == 0
+ n_replicate = tp // num_key_value_heads
+ hf_config.num_replicate_key_value_heads = n_replicate
+ num_key_value_heads = tp
+
+ hf_config.num_key_value_heads = num_key_value_heads
+ return num_key_value_heads
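
A worked example of `update_num_kv_heads`, assuming LMDeploy is importable; a `SimpleNamespace` stands in for the HF config:

    from types import SimpleNamespace
    from lmdeploy.pytorch.configurations.builder import AutoModelConfigBuilder

    hf_config = SimpleNamespace(num_key_value_heads=2)
    new_kv = AutoModelConfigBuilder.update_num_kv_heads(hf_config, tp=4,
                                                        num_key_value_heads=2)
    # tp(4) > kv_heads(2): each kv head is replicated twice
    assert new_kv == 4
    assert hf_config.num_replicate_key_value_heads == 2
    assert hf_config.num_key_value_heads == 4
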
diff --git a/lmdeploy/pytorch/configurations/chatglm.py b/lmdeploy/pytorch/configurations/chatglm.py
index 7911c985d5..fbf4d48281 100644
--- a/lmdeploy/pytorch/configurations/chatglm.py
+++ b/lmdeploy/pytorch/configurations/chatglm.py
@@ -12,16 +12,27 @@ def condition(cls, hf_config):
return hf_config.model_type == 'chatglm'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = hf_config.hidden_size // hf_config.num_attention_heads
bos_token_id = hf_config.bos_token_id
if bos_token_id is None:
bos_token_id = hf_config.pad_token_id
+
+ if hf_config.multi_query_attention:
+ num_key_value_heads = hf_config.multi_query_group_num
+ else:
+ num_key_value_heads = hf_config.num_attention_heads
+
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
+ num_key_value_heads)
+
cfg = ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_layers,
num_attention_heads=hf_config.num_attention_heads,
- num_key_value_heads=hf_config.multi_query_group_num,
+ num_key_value_heads=num_key_value_heads,
bos_token_id=bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
diff --git a/lmdeploy/pytorch/configurations/cogvlm.py b/lmdeploy/pytorch/configurations/cogvlm.py
index b24d92d794..4736dfee69 100644
--- a/lmdeploy/pytorch/configurations/cogvlm.py
+++ b/lmdeploy/pytorch/configurations/cogvlm.py
@@ -12,12 +12,15 @@ def condition(cls, hf_config):
return model_arch == 'CogVLMForCausalLM'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from lmdeploy.utils import is_bf16_supported
- cfg = DefaultModelConfigBuilder.build(hf_config)
if getattr(hf_config, 'num_multi_query_heads', None):
- cfg.num_key_value_heads = hf_config.num_multi_query_heads
+ hf_config.num_key_value_heads = hf_config.num_multi_query_heads
+ else:
+ hf_config.num_key_value_heads = hf_config.num_attention_heads
+
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
cfg.cogvlm_style = True
torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'
hf_config.torch_dtype = torch_dtype
diff --git a/lmdeploy/pytorch/configurations/dbrx.py b/lmdeploy/pytorch/configurations/dbrx.py
index 2c8128a5a6..dcc1222b0d 100644
--- a/lmdeploy/pytorch/configurations/dbrx.py
+++ b/lmdeploy/pytorch/configurations/dbrx.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return hf_config.model_type == 'dbrx'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
hidden_size = hf_config.d_model
num_heads = hf_config.n_heads
diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py
index 37aa4b0d69..d1f0844ad5 100644
--- a/lmdeploy/pytorch/configurations/deepseek_v2.py
+++ b/lmdeploy/pytorch/configurations/deepseek_v2.py
@@ -12,13 +12,19 @@ def condition(cls, hf_config):
return hf_config.model_type == 'deepseek_v2'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim)
k_head_dim = head_dim
v_head_dim = 0
num_attention_heads = hf_config.num_attention_heads
+ # multi query attn
num_key_value_heads = 1
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
+ num_key_value_heads)
+
return ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
num_attention_heads=num_attention_heads,
@@ -28,5 +34,4 @@ def build(cls, hf_config, model_path: str = None):
head_dim=head_dim,
k_head_dim=k_head_dim,
v_head_dim=v_head_dim,
- vocab_size=hf_config.vocab_size,
- multi_query_attention=True)
+ vocab_size=hf_config.vocab_size)
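
For orientation, DeepSeek-V2's MLA cache holds one compressed latent per token: `head_dim` above packs the kv LoRA rank plus the rope dims, `v_head_dim` is 0 because there is no separate value cache, and the single kv head is replicated across tp by `update_num_kv_heads`. A quick arithmetic sketch, with dims taken from the published DeepSeek-V2 config purely as an illustration:

    kv_lora_rank, qk_rope_head_dim = 512, 64
    head_dim = kv_lora_rank + qk_rope_head_dim   # 576 cached elements per token
    v_head_dim = 0                               # value is reconstructed from the latent
    num_key_value_heads = 1                      # replicated to tp when tp > 1
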
diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py
index 1f84b810ea..d1337a241e 100644
--- a/lmdeploy/pytorch/configurations/default.py
+++ b/lmdeploy/pytorch/configurations/default.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return True
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = hf_config.hidden_size // hf_config.num_attention_heads
num_attention_heads = hf_config.num_attention_heads
@@ -23,6 +23,11 @@ def build(cls, hf_config, model_path: str = None):
if use_sliding_window:
sliding_window = getattr(hf_config, 'sliding_window',
sliding_window) or -1
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
+ num_key_value_heads)
+
return ModelConfig(
hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
diff --git a/lmdeploy/pytorch/configurations/falcon.py b/lmdeploy/pytorch/configurations/falcon.py
index db4d00e397..a4c8d4d44f 100644
--- a/lmdeploy/pytorch/configurations/falcon.py
+++ b/lmdeploy/pytorch/configurations/falcon.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return hf_config.model_type == 'falcon'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build falcon."""
num_attention_heads = hf_config.num_attention_heads
if hf_config.new_decoder_architecture:
@@ -24,6 +24,12 @@ def build(cls, hf_config, model_path: str = None):
else:
# rw-1b, MHA
kv_head = num_attention_heads
+
+ tp = kwargs.get('tp', 1)
+ # update num_kv_heads for tp mode
+ kv_head = cls.update_num_kv_heads(hf_config, tp, kv_head)
+ hf_config.num_kv_heads = kv_head
+
head_dim = hf_config.hidden_size // num_attention_heads
return ModelConfig(
hidden_size=hf_config.hidden_size,
@@ -33,6 +39,5 @@ def build(cls, hf_config, model_path: str = None):
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
- multi_query_attention=hf_config.multi_query,
vocab_size=hf_config.vocab_size,
)
diff --git a/lmdeploy/pytorch/configurations/gemma.py b/lmdeploy/pytorch/configurations/gemma.py
index 338eaee6d0..d49fdbd96c 100644
--- a/lmdeploy/pytorch/configurations/gemma.py
+++ b/lmdeploy/pytorch/configurations/gemma.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from lmdeploy.pytorch.config import ModelConfig
-
from .builder import AutoModelConfigBuilder
+from .default import DefaultModelConfigBuilder
class GemmaModelConfigBuilder(AutoModelConfigBuilder):
@@ -12,13 +11,8 @@ def condition(cls, hf_config):
return hf_config.model_type in ['gemma', 'gemma2']
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build gemma."""
- return ModelConfig(hidden_size=hf_config.hidden_size,
- num_layers=hf_config.num_hidden_layers,
- num_attention_heads=hf_config.num_attention_heads,
- num_key_value_heads=hf_config.num_key_value_heads,
- bos_token_id=hf_config.bos_token_id,
- eos_token_id=hf_config.eos_token_id,
- head_dim=hf_config.head_dim,
- vocab_size=hf_config.vocab_size)
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
+ cfg.head_dim = hf_config.head_dim
+ return cfg
diff --git a/lmdeploy/pytorch/configurations/internvl.py b/lmdeploy/pytorch/configurations/internvl.py
index 76b4187c5f..ffff0a0e15 100644
--- a/lmdeploy/pytorch/configurations/internvl.py
+++ b/lmdeploy/pytorch/configurations/internvl.py
@@ -11,8 +11,9 @@ def condition(cls, hf_config):
return hf_config.architectures[0] == 'InternVLChatModel'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
- cfg = DefaultModelConfigBuilder.build(hf_config.llm_config)
+ cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path,
+ **kwargs)
cfg.hf_config = hf_config
return cfg
diff --git a/lmdeploy/pytorch/configurations/llava.py b/lmdeploy/pytorch/configurations/llava.py
deleted file mode 100644
index aaeeeeadfe..0000000000
--- a/lmdeploy/pytorch/configurations/llava.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .builder import AutoModelConfigBuilder
-from .default import DefaultModelConfigBuilder
-
-
-class LlavaModelConfigBuilder(AutoModelConfigBuilder):
-
- @classmethod
- def condition(cls, hf_config):
- """config."""
- return hf_config.architectures[0] in [
- 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM'
- ]
-
- @classmethod
- def build(cls, hf_config, model_path: str = None):
- """build."""
- arch = hf_config.architectures[0]
- if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']:
- from llava.model.language_model.llava_llama import LlavaConfig
-
- # reload hf_config due to model_type='llava' is already
- # registered in transformers
- hf_config = LlavaConfig.from_pretrained(model_path)
- cfg = DefaultModelConfigBuilder.build(hf_config)
- return cfg
diff --git a/lmdeploy/pytorch/configurations/llava_hf.py b/lmdeploy/pytorch/configurations/llava_hf.py
index 4cc007e313..5334eaec25 100644
--- a/lmdeploy/pytorch/configurations/llava_hf.py
+++ b/lmdeploy/pytorch/configurations/llava_hf.py
@@ -15,7 +15,7 @@ def condition(cls, hf_config):
]
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
text_config = hf_config.text_config
hidden_size = getattr(text_config, 'hidden_size', 4096)
diff --git a/lmdeploy/pytorch/configurations/minicpm3.py b/lmdeploy/pytorch/configurations/minicpm3.py
index 7cde51bd42..857673aab3 100644
--- a/lmdeploy/pytorch/configurations/minicpm3.py
+++ b/lmdeploy/pytorch/configurations/minicpm3.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from lmdeploy.pytorch.config import ModelConfig
from .builder import AutoModelConfigBuilder
+from .default import DefaultModelConfigBuilder
class MiniCPM3ModelConfigBuilder(AutoModelConfigBuilder):
@@ -12,21 +12,13 @@ def condition(cls, hf_config):
return hf_config.architectures[0] in ['MiniCPM3ForCausalLM']
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim)
- k_head_dim = head_dim
- v_head_dim = head_dim
- num_attention_heads = hf_config.num_attention_heads
- num_key_value_heads = hf_config.num_key_value_heads
- return ModelConfig(hidden_size=hf_config.hidden_size,
- num_layers=hf_config.num_hidden_layers,
- num_attention_heads=num_attention_heads,
- num_key_value_heads=num_key_value_heads,
- bos_token_id=hf_config.bos_token_id,
- eos_token_id=hf_config.eos_token_id,
- head_dim=head_dim,
- k_head_dim=k_head_dim,
- v_head_dim=v_head_dim,
- vocab_size=hf_config.vocab_size,
- multi_query_attention=False)
+
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
+ cfg.head_dim = head_dim
+ cfg.k_head_dim = head_dim
+ cfg.v_head_dim = head_dim
+
+ return cfg
diff --git a/lmdeploy/pytorch/configurations/mllama.py b/lmdeploy/pytorch/configurations/mllama.py
index 2383c92c50..e56e0fbed4 100644
--- a/lmdeploy/pytorch/configurations/mllama.py
+++ b/lmdeploy/pytorch/configurations/mllama.py
@@ -11,8 +11,9 @@ def condition(cls, hf_config):
return hf_config.architectures[0] == 'MllamaForConditionalGeneration'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
- cfg = DefaultModelConfigBuilder.build(hf_config.text_config)
+ cfg = DefaultModelConfigBuilder.build(hf_config.text_config,
+ model_path, **kwargs)
cfg.hf_config = hf_config
return cfg
diff --git a/lmdeploy/pytorch/configurations/qwen.py b/lmdeploy/pytorch/configurations/qwen.py
index 05ac77c1d1..eda726de43 100644
--- a/lmdeploy/pytorch/configurations/qwen.py
+++ b/lmdeploy/pytorch/configurations/qwen.py
@@ -11,10 +11,10 @@ def condition(cls, hf_config):
return hf_config.model_type == 'qwen'
@classmethod
- def build(cls, hf_config, model_path: str = None):
+ def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from lmdeploy.utils import is_bf16_supported
- cfg = DefaultModelConfigBuilder.build(hf_config)
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
if cfg.bos_token_id is None:
cfg.bos_token_id = 151644
if cfg.eos_token_id is None:
diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py
index e393adeed3..e3f97cfe46 100644
--- a/lmdeploy/pytorch/engine/cache_engine.py
+++ b/lmdeploy/pytorch/engine/cache_engine.py
@@ -44,7 +44,13 @@ def __init__(
self.num_layers = model_config.num_layers
self.kv_cache_dtype = model_config.dtype
if cache_config.quant_policy > 0:
- self.kv_cache_dtype = torch.uint8
+ if self.cache_config.device_type in ['cuda']:
+ self.kv_cache_dtype = torch.uint8
+ elif self.cache_config.device_type in ['ascend', 'npu']:
+ self.kv_cache_dtype = torch.int8
+ else:
+ raise ValueError(
+ f'unsupported device_type {self.cache_config.device_type}')
# Initialize the cache.
self.local_gpu_cache = self.allocate_gpu_cache()
@@ -92,7 +98,7 @@ def _get_key_block_shape_impl(cls,
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
- if local and not model_config.multi_query_attention:
+ if local:
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
@@ -115,7 +121,7 @@ def _get_value_block_shape_impl(cls,
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
- if local and not model_config.multi_query_attention:
+ if local:
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
@@ -202,7 +208,7 @@ def allocate_gpu_cache(self):
def allocate_cpu_cache(self):
"""allocate caches on Host."""
- caches = self._allocate_cache(self.num_gpu_blocks, 'cpu')
+ caches = self._allocate_cache(self.num_cpu_blocks, 'cpu')
self.full_cpu_cache = caches
self.local_cpu_cache = list(zip(*caches))
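
With `quant_policy > 0` the kv-cache dtype now depends on the device: `uint8` on CUDA and `int8` on Ascend/NPU. A small helper mirroring that branch, for illustration only:

    import torch

    def kv_cache_dtype(model_dtype: torch.dtype, quant_policy: int,
                       device_type: str) -> torch.dtype:
        # mirrors the CacheEngine.__init__ branch above
        if quant_policy == 0:
            return model_dtype
        if device_type in ['cuda']:
            return torch.uint8
        if device_type in ['ascend', 'npu']:
            return torch.int8
        raise ValueError(f'unsupported device_type {device_type}')

    assert kv_cache_dtype(torch.float16, 8, 'ascend') is torch.int8
    assert kv_cache_dtype(torch.float16, 0, 'cuda') is torch.float16
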
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index b7a803a7a7..e06e0cf80a 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -8,19 +8,17 @@
import numpy as np
import torch
-from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
- ResponseType)
+from lmdeploy.messages import PytorchEngineConfig, ResponseType
from lmdeploy.utils import (get_logger, get_max_batch_size, get_model,
logging_timer)
from ..adapter.adapter import AdapterManager
-from ..check_env import check_adapters, check_env, check_model
from ..config import BackendConfig, CacheConfig, SchedulerConfig
from ..devices import DeviceContext, get_device_manager
-from ..messages import (InputEmbeddingRangeType, InputEmbeddingType,
- MessageStatus, SchedulerSequence)
-from ..model_inputs import ModelInputs, MRopeModelInputs, VisionModelInputs
+from ..messages import MessageStatus, SchedulerSequence
+from ..model_inputs import ModelInputs, VisionModelInputs
from ..paging import Scheduler
+from .engine_checker import EngineChecker
from .logits_process import FusedLogitsProcessor, SamplingInputs
from .model_agent import build_model_agent
from .request import Request, RequestManager, RequestType, Response
@@ -78,6 +76,40 @@ def _check_finish(scheduler: Scheduler, current_iter: int):
return False
+def _build_scheduler_config(engine_config: PytorchEngineConfig):
+ """build scheduler config."""
+ scheduler_config = SchedulerConfig(
+ max_batches=engine_config.max_batch_size,
+ max_session_len=engine_config.session_len,
+ prefill_interval=engine_config.prefill_interval)
+ return scheduler_config
+
+
+def _build_cache_config(engine_config: PytorchEngineConfig):
+ """build cache config."""
+ cache_config = CacheConfig(
+ max_batches=engine_config.max_batch_size,
+ block_size=engine_config.block_size,
+ num_cpu_blocks=engine_config.num_cpu_blocks,
+ num_gpu_blocks=engine_config.num_gpu_blocks,
+ cache_max_entry_count=engine_config.cache_max_entry_count,
+ max_prefill_token_num=engine_config.max_prefill_token_num,
+ enable_prefix_caching=engine_config.enable_prefix_caching,
+ quant_policy=engine_config.quant_policy,
+ device_type=engine_config.device_type,
+ )
+ return cache_config
+
+
+def _build_backend_config(engine_config: PytorchEngineConfig):
+ """build backend config."""
+ backend_config = BackendConfig(
+ eager_mode=engine_config.eager_mode,
+ device_type=engine_config.device_type,
+ )
+ return backend_config
+
+
class Engine:
"""The inference engine of lmdeploy pytorch.
@@ -95,43 +127,23 @@ def __init__(self,
engine_config = PytorchEngineConfig()
else:
engine_config = copy.deepcopy(engine_config)
- check_env(engine_config.device_type)
- check_model(model_path, trust_remote_code, engine_config.dtype,
- engine_config.device_type)
if engine_config.max_batch_size is None:
engine_config.max_batch_size = get_max_batch_size(
engine_config.device_type)
- adapters = engine_config.adapters
- if adapters is not None:
- check_adapters(list(adapters.values()))
- assert engine_config.max_batch_size > 0, 'max_batch_size should be' \
- f' greater than 0, but got {engine_config.max_batch_size}'
- assert engine_config.dtype in ['auto', 'float16', 'bfloat16'], \
- f'unsupported specified data type {engine_config.dtype}'
+ checker = EngineChecker(model_path=model_path,
+ engine_config=engine_config,
+ trust_remote_code=trust_remote_code,
+ logger=logger)
+ checker.handle()
+
+ adapters = engine_config.adapters
self.engine_config = engine_config
self.tp = engine_config.tp
self.device_context = DeviceContext(
device_type=engine_config.device_type)
- scheduler_config = SchedulerConfig(
- max_batches=engine_config.max_batch_size,
- max_session_len=engine_config.session_len,
- prefill_interval=engine_config.prefill_interval)
-
- # block_size = 1 to enable unified paging
- cache_config = CacheConfig(
- max_batches=engine_config.max_batch_size,
- block_size=engine_config.block_size,
- num_cpu_blocks=engine_config.num_cpu_blocks,
- num_gpu_blocks=engine_config.num_gpu_blocks,
- cache_max_entry_count=engine_config.cache_max_entry_count,
- max_prefill_token_num=engine_config.max_prefill_token_num,
- enable_prefix_caching=engine_config.enable_prefix_caching,
- quant_policy=engine_config.quant_policy,
- )
-
if not os.path.exists(model_path):
model_path = get_model(model_path, engine_config.download_dir,
engine_config.revision)
@@ -140,10 +152,9 @@ def __init__(self,
if adapters is not None and len(adapters) > 0:
adapters = self._download_adapters(adapters, engine_config)
- backend_config = BackendConfig(
- eager_mode=engine_config.eager_mode,
- device_type=engine_config.device_type,
- )
+ scheduler_config = _build_scheduler_config(engine_config)
+ cache_config = _build_cache_config(engine_config)
+ backend_config = _build_backend_config(engine_config)
with get_device_manager().context(self.device_context):
self.model_agent = build_model_agent(
@@ -156,6 +167,8 @@ def __init__(self,
dtype=engine_config.dtype,
custom_module_map=engine_config.custom_module_map)
+ self.input_processor = self.model_agent.get_input_processor()
+
cache_config = self.model_agent.cache_config
self.adapter_manager = self._build_adapter_manager(adapters)
self.scheduler = Scheduler(scheduler_config, cache_config)
@@ -171,7 +184,6 @@ def __init__(self,
# create main thread
self._start_loop()
self._create_buffers()
- self.engine_instance = self.create_instance()
self._output_stream = torch.cuda.Stream()
@classmethod
@@ -316,6 +328,10 @@ def _on_end_session(self, reqs: Request, **kwargs):
def _on_add_message(self, reqs: Request, **kwargs):
"""on add message callback."""
+ self._msg_preprocess_inque.put_nowait(reqs)
+
+ def _add_message(self, que):
+
def __update_bad_words(msg):
"""update bad words."""
sampling_param = msg.sampling_param
@@ -337,6 +353,11 @@ def __update_max_new_tokens(msg):
sampling_param.max_new_tokens,
max_session_len - msg.num_all_tokens())
+ if que.qsize() == 0:
+ return
+
+ reqs = que.get_nowait()
+
for req in reqs:
session_id = req.data['session_id']
if session_id not in self.scheduler.sessions:
@@ -354,11 +375,8 @@ def __update_max_new_tokens(msg):
sampling_param=req.data['sampling_param'],
adapter_name=req.data['adapter_name'],
return_logits=req.data.get('return_logits', False),
+ multimodals=req.data.get('input_multimodals'),
input_embeddings=req.data.get('input_embeddings'),
- mrope_position_ids=req.data.get('mrope_position_ids'),
- mrope_position_delta=req.data.get('mrope_position_delta'),
- cross_attention_states=req.data.get(
- 'cross_attention_states'),
)
msg = next(iter(sess.sequences.values()))
__update_bad_words(msg)
@@ -366,9 +384,11 @@ def __update_max_new_tokens(msg):
self.scheduler.add_sequence(msg)
else:
msg = next(iter(sess.sequences.values()))
- msg.update_token_ids(req.data['token_ids'],
- req.data.get('input_embeddings'),
- req.data.get('cross_attention_states'))
+ msg.update_token_ids(
+ req.data['token_ids'],
+ multimodals=req.data.get('input_multimodals'),
+ embeddings=req.data.get('input_embeddings'),
+ )
msg.num_new_tokens = 0
msg.sampling_param = req.data['sampling_param']
msg.return_logits = req.data.get('return_logits', False)
@@ -414,7 +434,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
seq_length = self._seq_length_buf[:batch_size]
max_q_seq_length = seq_length.max().item()
- # TODO: get block offsets is slow when block_size = 1
block_offsets = self.scheduler.get_block_tables(messages)
block_offsets = _tensorlize_block_offsets(block_offsets)
@@ -432,13 +451,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
num_ignored_history = [msg.num_ignored_history for msg in messages]
num_ignored_history = torch.tensor(num_ignored_history)
- def __get_cogvlm_image_info():
- """Get cogvlm history image info for position ids."""
- history_image_nums = torch.LongTensor(
- [msg.history_image_num for msg in messages])
- history_image_token_lengths = torch.LongTensor(
- [msg.history_image_token_len for msg in messages])
- return history_image_nums, history_image_token_lengths
+ model_metas = [msg.model_meta for msg in messages]
def __get_vlm_embeddings():
"""get vlm input embeddings and indexings."""
@@ -463,25 +476,9 @@ def __get_vlm_embeddings():
return (input_embeddings, input_embedding_indexing,
input_embedding_ranges)
- def __get_mrope_inputs():
- """get multimodal rotary position inputs."""
- position_ids = [msg.mrope_position_ids for msg in messages]
- deltas = [msg.mrope_position_delta for msg in messages]
- return MRopeModelInputs(position_ids=position_ids, deltas=deltas)
-
# for inputs with embeddings
history_image_nums = None
history_image_token_lengths = None
- # only for cogvlm
- if self.model_config.cogvlm_style:
- (history_image_nums,
- history_image_token_lengths) = __get_cogvlm_image_info()
- # only for qwen2_vl
- mrope_inputs = None
- has_mrope_params = any(
- [msg.mrope_position_ids is not None for msg in messages])
- if has_mrope_params:
- mrope_inputs = __get_mrope_inputs()
input_embeddings = None
input_embedding_indexing = None
@@ -492,25 +489,40 @@ def __get_mrope_inputs():
(input_embeddings, input_embedding_indexing,
input_embedding_ranges) = __get_vlm_embeddings()
+ input_multimodals = None
+ has_multimodal = any(
+ [not msg.history_multimodals.empty() for msg in messages])
+ if has_multimodal:
+ has_multimodal = False
+ input_multimodals = [
+ msg.get_input_multimodals() for msg in messages
+ ]
+ for input_mm in input_multimodals:
+ for val in input_mm.values():
+ if len(val) > 0:
+ has_multimodal = True
+ break
+ if has_multimodal:
+ break
+
vision_embedding_inputs = None
- if has_embedding or history_image_nums is not None:
+ if has_embedding or has_multimodal or history_image_nums is not None:
vision_embedding_inputs = VisionModelInputs(
history_lengths=history_lengths,
history_image_nums=history_image_nums,
history_image_token_lengths=history_image_token_lengths,
input_embeddings=input_embeddings,
input_embedding_indexing=input_embedding_indexing,
- input_embedding_ranges=input_embedding_ranges)
-
- # only for mllama
- cross_attention_states = None
- history_cross_kv_seqlens = None
- if any([msg.cross_attention_states is not None for msg in messages]):
- cross_attention_states = [
- msg.cross_attention_states for msg in messages
- ]
- history_cross_kv_seqlens = torch.tensor(
- [msg.history_cross_kv_seqlens for msg in messages])
+ input_embedding_ranges=input_embedding_ranges,
+ input_multimodals=input_multimodals)
+
+ # cross-attention token lengths
+ cross_length = torch.tensor([msg.num_cross for msg in messages])
+ history_cross_length = torch.tensor(
+ [msg.num_history_cross for msg in messages])
+ if (cross_length + history_cross_length).max().item() == 0:
+ cross_length = None
+ history_cross_length = None
return ModelInputs(
input_ids=input_ids,
@@ -521,9 +533,9 @@ def __get_mrope_inputs():
num_ignored_history=num_ignored_history,
local_adapter_ids=local_adapter_ids,
vision_inputs=vision_embedding_inputs,
- mrope_inputs=mrope_inputs,
- cross_attention_states=cross_attention_states,
- history_cross_kv_seqlens=history_cross_kv_seqlens,
+ cross_length=cross_length,
+ history_cross_length=history_cross_length,
+ model_metas=model_metas,
)
def _batch_stopping_criteria(self, token_ids: torch.Tensor,
@@ -567,11 +579,15 @@ def __get_last_logits():
@logging_timer('UpdateRunning', logger)
def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
- stopped: torch.Tensor):
+ stopped: torch.Tensor, model_metas: List[Dict[str,
+ Any]]):
"""update scheduler."""
+ if model_metas is None:
+ model_metas = [None] * len(running)
next_token_ids = next_token_ids.numpy()
eos_token_id = self.model_config.eos_token_id
- for token, msg, stop in zip(next_token_ids, running, stopped):
+ for token, msg, stop, model_meta in zip(next_token_ids, running,
+ stopped, model_metas):
if msg.status != MessageStatus.RUNNING:
continue
update_token = token
@@ -580,7 +596,7 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
update_token = _EMPTY_TOKEN
else:
msg.num_new_tokens += 1
- msg.update_token_ids(update_token)
+ msg.update_token_ids(update_token, model_meta=model_meta)
if stop:
msg.status = MessageStatus.STOPPED
@@ -646,12 +662,14 @@ async def __long_context_single_forward(inputs):
batch_size = seq_len.size(0)
assert batch_size == 1
- new_inputs = inputs.split(max_prefill_token_num,
- self.cache_config.block_size)
+ new_inputs = inputs.split(max_prefill_token_num)
+ model_metas = new_inputs[0].model_metas
output_gather = _OutputGather(max_seq_len)
for inp in new_inputs:
+ inp.model_metas = model_metas
tmp_out = await __forward(inp)
+ model_metas = tmp_out.get('model_metas')
output_gather.gather(tmp_out)
tmp_out.pop('hidden_states', None)
tmp_out['hidden_states'] = output_gather.get_output()
@@ -673,9 +691,10 @@ async def __long_context_single_forward(inputs):
ret['logits'] = logits
return ret
- def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
- logits: torch.Tensor, stopped: torch.Tensor,
- event: torch.cuda.Event):
+ async def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
+ logits: torch.Tensor, stopped: torch.Tensor,
+ model_metas: List[Dict[str, Any]],
+ event: torch.cuda.Event):
"""make infer output."""
def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence,
@@ -696,15 +715,16 @@ def __get_q_start_loc():
else:
return seq_length.cumsum(0) - seq_length
+ while not event.query():
+ await asyncio.sleep(0.001)
with torch.cuda.stream(self._output_stream):
- event.wait()
next_token_ids = next_token_ids.cpu()
stopped = stopped.cpu()
running = self._running
is_run = [seq.status == MessageStatus.RUNNING for seq in running]
stopped = stopped.tolist()
- self.update_running(running, next_token_ids, stopped)
+ self.update_running(running, next_token_ids, stopped, model_metas)
# generate output
next_token_ids = next_token_ids.tolist()
@@ -762,8 +782,7 @@ def __update_inputs(next_token_ids):
logger.debug(': '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)}')
- if self.gpu_count == 1:
- inputs = inputs.to_device('cuda')
+ inputs = inputs.to_device('cuda')
is_decoding = inputs.is_decoding
if all_ids is not None:
all_ids = all_ids.cuda()
@@ -794,13 +813,16 @@ def __update_inputs(next_token_ids):
next_token_ids, sampling_inputs.stop_words, num_appendable_ids)
# send output
+ model_metas = output.get('model_metas')
finish = (idx == loop_count - 1)
finish = finish or _check_finish(self.scheduler, idx)
event = torch.cuda.Event()
event.record()
- output = (next_token_ids, logits, stopped, event)
+ output = (next_token_ids, logits, stopped, model_metas, event)
output_que.put_nowait((finish, output))
+ inputs.model_metas = model_metas
+
if finish:
break
@@ -810,6 +832,36 @@ def __update_inputs(next_token_ids):
swap_out_map = dict()
__update_inputs(next_token_ids)
+ @torch.inference_mode()
+ async def _async_loop_preprocess_message(self, inque, outque):
+ """preprocess msg."""
+ while True:
+ reqs = await inque.get()
+
+ for req in reqs:
+ req_data = req.data
+ if req_data.get('input_multimodals', None) is None:
+ continue
+ elif self.input_processor is None:
+ logger.warning('Multimodal inputs are not supported.')
+ continue
+ input_ids = req_data['token_ids']
+ input_multimodals = req_data['input_multimodals']
+ if len(input_multimodals) == 0:
+ req_data['input_multimodals'] = None
+ continue
+ result = self.input_processor.preprocess_input(
+ input_ids, input_multimodals)
+
+ input_ids = result.input_ids
+ input_multimodals = result.input_multimodals
+
+ req_data['token_ids'] = input_ids
+ req_data['input_multimodals'] = input_multimodals
+
+ if len(reqs) > 0:
+ outque.put_nowait(reqs)
+
@torch.inference_mode()
async def _async_loop_background(self, in_que: asyncio.Queue,
out_que: asyncio.Queue):
@@ -918,6 +970,10 @@ async def _async_loop(self):
Each engine instance would communicate with the engine by queue.
"""
+
+ self._msg_preprocess_inque = asyncio.Queue()
+ self._msg_preprocess_outque = asyncio.Queue()
+
prefill_interval = self.scheduler_config.prefill_interval
in_que = asyncio.Queue()
out_que = asyncio.Queue()
@@ -926,6 +982,12 @@ async def _async_loop(self):
name='MainLoopBackground')
loop_background.add_done_callback(_raise_exception_on_finish)
+ loop_msg_proc = asyncio.get_event_loop().create_task(
+ self._async_loop_preprocess_message(self._msg_preprocess_inque,
+ self._msg_preprocess_outque),
+ name='MainLoopPreprocessMessage')
+ loop_msg_proc.add_done_callback(_raise_exception_on_finish)
+
def __send_resp(out: InferOutput):
"""send response."""
resp_type = (ResponseType.FINISH
@@ -957,13 +1019,14 @@ async def __step():
while not finish:
if self.req_manager.has_requests():
self.req_manager.step()
+ self._add_message(self._msg_preprocess_outque)
finish, out = await out_que.get()
try:
if isinstance(out, Exception):
raise out
- next_token_ids, logits, stopped, event = out
- step_outputs = self._make_infer_outputs(
- next_token_ids, logits, stopped, event)
+ (next_token_ids, logits, stopped, model_metas, event) = out
+ step_outputs = await self._make_infer_outputs(
+ next_token_ids, logits, stopped, model_metas, event)
__send_resps(step_outputs)
except Exception as e:
raise e
@@ -973,6 +1036,7 @@ async def __step():
while True:
if self.req_manager.has_requests():
self.req_manager.step()
+ self._add_message(self._msg_preprocess_outque)
if not self.scheduler.has_unfinished():
await asyncio.sleep(0.01)
@@ -996,78 +1060,3 @@ def create_instance(self, cuda_stream_id=0):
"""
from .engine_instance import EngineInstance
return EngineInstance(self)
-
- async def async_batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None):
- """Send inference request.
-
- Args:
- session_ids (List[int]): The session id.
- token_ids (List[int]): The input token ids.
- gen_config (GenerationConfig): The sampling parameters.
- adapter_names (List[str]): The name of the adapters.
- keep_cache (bool): Keep kv cache after infer.
-
- Returns:
- int: Error flags. 0 if success.
- List[int]: The streaming output tokens.
- int: The number of the output tokens.
- """
- return await self.engine_instance.async_batched_infer(
- session_ids=session_ids,
- token_ids=token_ids,
- gen_config=gen_config,
- adapter_names=adapter_names,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- keep_cache=keep_cache)
-
- def batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None):
- """batched infer."""
- return self.engine_instance.batched_infer(
- session_ids=session_ids,
- token_ids=token_ids,
- gen_config=gen_config,
- adapter_names=adapter_names,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- keep_cache=keep_cache)
-
- async def async_add_session(self, session_id: int):
- """Add new session."""
- return await self.engine_instance._async_try_add_session(session_id)
-
- def add_session(self, session_id: int):
- """Add new session."""
- return self.engine_instance._try_add_session(session_id)
-
- async def async_cancel(self, session_id: int):
- """Stop the given session."""
- return await self.engine_instance.async_cancel(session_id)
-
- def cancel(self, session_id: int):
- """Add new session."""
- return self.engine_instance.cancel(session_id)
-
- async def async_end(self, session_id: int):
- """End the given session."""
- return await self.engine_instance.async_end(session_id)
-
- def end(self, session_id: int):
- """Add new session."""
- return self.engine_instance.end(session_id)
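
The engine changes above replace the direct `_on_add_message` path with an asyncio queue pair (`_msg_preprocess_inque`/`_msg_preprocess_outque`): requests are enqueued by the callback, preprocessed by `_async_loop_preprocess_message`, and drained into the scheduler by `_add_message`. The sketch below is a minimal, self-contained illustration of that producer/consumer pattern only; `preprocess_loop`, `drain`, and the request dicts are illustrative stand-ins, not LMDeploy APIs.

```python
import asyncio
from typing import Any, Dict, List


async def preprocess_loop(inque: asyncio.Queue, outque: asyncio.Queue):
    """Background task: preprocess each request batch, then forward it."""
    while True:
        reqs: List[Dict[str, Any]] = await inque.get()
        for req in reqs:
            if req.get('input_multimodals'):
                # stand-in for input_processor.preprocess_input(...)
                req['token_ids'] = list(req['token_ids'])
        if len(reqs) > 0:
            outque.put_nowait(reqs)


def drain(outque: asyncio.Queue):
    """Main-loop helper (mirrors _add_message): schedule finished batches."""
    if outque.qsize() == 0:
        return
    reqs = outque.get_nowait()
    print('schedule sessions', [r['session_id'] for r in reqs])


async def main():
    inque, outque = asyncio.Queue(), asyncio.Queue()
    task = asyncio.create_task(preprocess_loop(inque, outque))
    inque.put_nowait([{'session_id': 0, 'token_ids': [1, 2, 3]}])
    await asyncio.sleep(0)  # give the preprocess task one chance to run
    drain(outque)
    task.cancel()


asyncio.run(main())
```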
diff --git a/lmdeploy/pytorch/engine/engine_checker.py b/lmdeploy/pytorch/engine/engine_checker.py
new file mode 100644
index 0000000000..7276a51fbc
--- /dev/null
+++ b/lmdeploy/pytorch/engine/engine_checker.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.messages import PytorchEngineConfig
+
+from ..check_env.adapter import AdapterChecker
+from ..check_env.base import BaseChecker
+from ..check_env.model import ModelChecker
+from ..check_env.torch import TorchChecker
+from ..check_env.transformers import TransformersChecker
+
+
+class EngineChecker(BaseChecker):
+ """check transformers is available."""
+
+ def __init__(self,
+ model_path: str,
+ engine_config: PytorchEngineConfig,
+ trust_remote_code: bool = True,
+ logger=None):
+ super().__init__(logger)
+ logger = self.get_logger()
+
+ self.engine_config = engine_config
+
+ dtype = engine_config.dtype
+ device_type = engine_config.device_type
+
+ # pytorch
+ torch_checker = TorchChecker(logger=logger)
+
+ if device_type == 'cuda':
+ # triton
+ from ..check_env.triton import TritonChecker
+ triton_checker = TritonChecker(logger=logger)
+ triton_checker.register_required_checker(torch_checker)
+ self.register_required_checker(triton_checker)
+ else:
+ # deeplink
+ from ..check_env.deeplink import DeeplinkChecker
+ dl_checker = DeeplinkChecker(device_type, logger=logger)
+ self.register_required_checker(dl_checker)
+ self.register_required_checker(torch_checker)
+
+ # transformers
+ trans_checker = TransformersChecker()
+
+ # model
+ model_checker = ModelChecker(model_path=model_path,
+ trust_remote_code=trust_remote_code,
+ dtype=dtype,
+ device_type=device_type,
+ logger=logger)
+ model_checker.register_required_checker(torch_checker)
+ model_checker.register_required_checker(trans_checker)
+ self.register_required_checker(model_checker)
+
+ # adapters
+ adapters = engine_config.adapters
+ if adapters is not None:
+ adapter_paths = list(adapters.values())
+ for adapter in adapter_paths:
+ adapter_checker = AdapterChecker(adapter, logger=logger)
+ self.register_required_checker(adapter_checker)
+
+ def check(self):
+ """check."""
+ engine_config = self.engine_config
+ logger = self.get_logger()
+
+ if engine_config.thread_safe:
+ logger.warning('thread safe mode has been deprecated and'
+ ' it would be removed in the future.')
+
+ if engine_config.max_batch_size <= 0:
+ self.log_and_exit(
+ mod_name='Engine',
+ message='max_batch_size should be'
+ f' greater than 0, but got {engine_config.max_batch_size}')
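
`EngineChecker` composes a small dependency graph of checkers (torch under triton on CUDA, torch and transformers under the model checker, one `AdapterChecker` per adapter) and `handle()` walks it before the engine starts. The toy classes below only illustrate that composition pattern under the assumption that required checkers run before the checker that registered them; they are not LMDeploy's `BaseChecker`.

```python
from typing import List


class ToyChecker:
    """toy stand-in for the checker composition pattern (not LMDeploy's BaseChecker)."""

    def __init__(self, name: str):
        self.name = name
        self._required: List['ToyChecker'] = []

    def register_required_checker(self, checker: 'ToyChecker'):
        self._required.append(checker)

    def handle(self):
        # required checkers run before the checker that registered them
        for checker in self._required:
            checker.handle()
        self.check()

    def check(self):
        print(f'checked {self.name}')


# mirrors the cuda branch above: triton requires torch, engine requires triton
torch_checker = ToyChecker('torch')
triton_checker = ToyChecker('triton')
triton_checker.register_required_checker(torch_checker)
engine_checker = ToyChecker('engine')
engine_checker.register_required_checker(triton_checker)
engine_checker.handle()  # prints: torch, triton, engine
```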
diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index 455ab1ccb3..dff9667eb4 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
+from typing import Any, Dict, List
from lmdeploy.messages import EngineOutput, GenerationConfig
from lmdeploy.utils import get_logger
-from ..messages import (InputEmbeddingRangeType, InputEmbeddings,
- InputEmbeddingType, SamplingParam)
+from ..messages import SamplingParam
from .engine import Engine
from .request import RequestSender, RequestType, Response, ResponseType
logger = get_logger('lmdeploy')
+InputMultiModalType = List[Dict[str, Any]]
+
def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None):
"""check if response has state."""
@@ -114,15 +115,13 @@ def _try_add_session(self, session_id: int):
"""
return try_add_session(self.req_sender, session_id)
- async def async_stream_infer(
- self,
- session_id: int,
- input_ids: List[int],
- gen_config: GenerationConfig = None,
- adapter_name: str = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
- **kwargs):
+ async def async_stream_infer(self,
+ session_id: int,
+ input_ids: List[int],
+ gen_config: GenerationConfig = None,
+ multimodal: InputMultiModalType = None,
+ adapter_name: str = None,
+ **kwargs):
"""Send stream inference request.
Args:
@@ -144,21 +143,13 @@ async def async_stream_infer(
await self.req_sender.async_send_async(
RequestType.ADD_SESSION, dict(session_id=session_id,
response=False))
- input_embeddings_new: List[InputEmbeddings] = None
- if input_embeddings is not None and len(input_embeddings) > 0:
- assert len(input_embeddings) == len(input_embedding_ranges)
- input_embeddings_new = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_embeddings, input_embedding_ranges)
- ]
- msg = dict(token_ids=input_ids,
- session_id=session_id,
- sampling_param=sampling_param,
- adapter_name=adapter_name,
- input_embeddings=input_embeddings_new,
- mrope_position_ids=kwargs.get('mrope_position_ids'),
- mrope_position_delta=kwargs.get('mrope_position_delta'),
- cross_attention_states=kwargs.get('cross_attention_states'))
+ msg = dict(
+ token_ids=input_ids,
+ session_id=session_id,
+ sampling_param=sampling_param,
+ adapter_name=adapter_name,
+ input_multimodals=multimodal,
+ )
req_id = await self.req_sender.async_send_async(
RequestType.ADD_MESSAGE, msg)
@@ -179,14 +170,12 @@ async def async_stream_infer(
yield EngineOutput(resp.type, [], 0)
break
- async def async_infer(
- self,
- session_id: int,
- input_ids: List[int] = None,
- gen_config: GenerationConfig = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
- **kwargs):
+ async def async_infer(self,
+ session_id: int,
+ input_ids: List[int] = None,
+ multimodal: InputMultiModalType = None,
+ gen_config: GenerationConfig = None,
+ **kwargs):
"""Send inference request.
Args:
@@ -200,13 +189,11 @@ async def async_infer(
int: The number of the output tokens.
"""
token_ids = []
- async for outputs in self.async_stream_infer(
- session_id,
- input_ids,
- gen_config=gen_config,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- **kwargs):
+ async for outputs in self.async_stream_infer(session_id,
+ input_ids,
+ multimodal=multimodal,
+ gen_config=gen_config,
+ **kwargs):
status, tmp_ids = outputs.status, outputs.token_ids
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return EngineOutput(status, token_ids, len(token_ids))
@@ -217,10 +204,9 @@ async def async_infer(
def stream_infer(self,
session_id: int,
input_ids: List[int],
+ multimodal: InputMultiModalType = None,
gen_config: GenerationConfig = None,
adapter_name: str = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
**kwargs):
"""Send stream inference request.
@@ -241,14 +227,12 @@ def stream_infer(self,
def __call_async():
"""call async."""
- coro_gen = self.async_stream_infer(
- session_id,
- input_ids,
- gen_config,
- adapter_name,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- **kwargs)
+ coro_gen = self.async_stream_infer(session_id,
+ input_ids,
+ multimodal=multimodal,
+ gen_config=gen_config,
+ adapter_name=adapter_name,
+ **kwargs)
while True:
try:
yield self.req_sender.run_until_complete(
@@ -264,19 +248,12 @@ def __call_async():
sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
self.req_sender.send_async(RequestType.ADD_SESSION,
dict(session_id=session_id, response=False))
- input_embeddings_new: List[InputEmbeddings] = None
- if input_embeddings is not None and len(input_embeddings) > 0:
- assert len(input_embeddings) == len(input_embedding_ranges)
- input_embeddings_new = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_embeddings, input_embedding_ranges)
- ]
msg = dict(
token_ids=input_ids,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name,
- input_embeddings=input_embeddings_new,
+ input_multimodals=multimodal,
)
req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
@@ -300,9 +277,8 @@ def __call_async():
def infer(self,
session_id: int,
input_ids: List[int] = None,
+ multimodal: InputMultiModalType = None,
gen_config: GenerationConfig = None,
- input_embeddings: InputEmbeddingType = None,
- input_embedding_ranges: InputEmbeddingRangeType = None,
**kwargs):
"""Send inference request.
@@ -317,13 +293,11 @@ def infer(self,
int: The number of the output tokens.
"""
token_ids = []
- for outputs in self.stream_infer(
- session_id,
- input_ids,
- gen_config=gen_config,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- **kwargs):
+ for outputs in self.stream_infer(session_id,
+ input_ids,
+ multimodal=multimodal,
+ gen_config=gen_config,
+ **kwargs):
status, tmp_ids = outputs.status, outputs.token_ids
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return EngineOutput(status, token_ids, len(token_ids))
@@ -331,127 +305,6 @@ def infer(self,
return EngineOutput(0, token_ids, len(token_ids))
- async def async_batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None,
- ):
- """Send inference request.
-
- Args:
- session_ids (List[int]): The session id.
- token_ids (List[int]): The input token ids.
- gen_config (GenerationConfig): The sampling parameters.
- adapter_names (List[str]): The name of the adapters.
- keep_cache (bool): Keep kv cache after infer.
-
- Returns:
- int: Error flags. 0 if success.
- List[int]: The streaming output tokens.
- int: The number of the output tokens.
- """
- batch_size = len(token_ids)
- assert len(session_ids) == batch_size
- if adapter_names is not None:
- assert len(adapter_names) == batch_size
- else:
- adapter_names = [None for _ in range(batch_size)]
-
- if input_embeddings is not None:
- assert len(input_embeddings) == batch_size
- assert len(input_embedding_ranges) == batch_size
- else:
- input_embeddings = [None] * batch_size
- input_embedding_ranges = [None] * batch_size
-
- async def _add_sessions(session_ids):
- for session_id in session_ids:
- await self._async_try_add_session(session_id)
-
- async def _add_messages(session_ids, token_ids, adapter_names,
- input_embeddings, input_embedding_ranges):
- add_msgs = []
- sampling_param = SamplingParam.from_gen_config(gen_config)
- for session_id, token_id, adapter_name, input_emb, input_ranges in zip( # noqa: E501
- session_ids, token_ids, adapter_names, input_embeddings,
- input_embedding_ranges):
- cur_input_embeddings: List[InputEmbeddings] = None
- if input_emb is not None and len(input_emb) > 0:
- assert len(input_emb) == len(input_ranges)
- cur_input_embeddings = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_emb, input_ranges)
- ]
- msg = dict(
- token_ids=token_id,
- session_id=session_id,
- sampling_param=sampling_param,
- adapter_name=adapter_name,
- input_embeddings=cur_input_embeddings,
- )
- add_msgs.append(msg)
- req_types = [RequestType.ADD_MESSAGE] * batch_size
- req_ids = await self.req_sender.async_batched_send_async(
- req_types, data=add_msgs)
- return req_ids
-
- await _add_sessions(session_ids)
- req_ids = await _add_messages(session_ids, token_ids, adapter_names,
- input_embeddings, input_embedding_ranges)
-
- # receive messages
- req_idx_map = dict(zip(req_ids, range(len(req_ids))))
- output_token_ids = [list() for _ in req_ids]
- status = 0
- finish_count = batch_size
- while finish_count:
- resp = await self.req_sender.async_recv_any()
- if resp.req_id not in req_ids:
- continue
- idx = req_idx_map[resp.req_id]
- token_ids = output_token_ids[idx]
- if resp.type == ResponseType.SUCCESS:
- token_ids += resp.data['token_ids']
- elif resp.type == ResponseType.FINISH:
- token_ids += resp.data['token_ids']
- if not keep_cache:
- session_id = session_ids[idx]
- await self.async_end(session_id=session_id)
- finish_count -= 1
- else:
- logger.error(f'Unexpected response: {resp.type}')
- status = 1
- break
-
- output_token_len = [len(token_ids) for token_ids in output_token_ids]
- return EngineOutput(status, output_token_ids, output_token_len)
-
- def batched_infer(
- self,
- session_ids: List[int],
- token_ids: List[List[int]] = None,
- gen_config: GenerationConfig = None,
- adapter_names: List[str] = None,
- keep_cache: bool = False,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None,
- ):
- """batched infer."""
- coro = self.async_batched_infer(
- session_ids,
- token_ids,
- gen_config=gen_config,
- adapter_names=adapter_names,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- keep_cache=keep_cache)
- return self.req_sender.run_until_complete(coro)
-
async def async_end(self, session_id: int):
"""End the given session."""
return await async_end(self.req_sender, session_id)
@@ -470,8 +323,7 @@ def cancel(self, session_id: int):
def decode(self,
input_ids,
- input_embeddings: List[InputEmbeddingType] = None,
- input_embedding_ranges: List[InputEmbeddingRangeType] = None,
+ multimodal: List[InputMultiModalType] = None,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
@@ -481,10 +333,8 @@ def decode(self,
Args:
input_ids (numpy.ndarray): the batch of input token ids
steps (List[int]): the offset of the k/v cache
- input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
- embeddings features
- input_embedding_ranges: (List[List[Tuple[int, int]]]):
- the begin/end offsets of input_embeddings to input_ids
+ multimodal (List[InputMultiModalType]):
+ multimodal inputs.
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
@@ -494,33 +344,24 @@ def decode(self,
batch_size = len(input_ids)
def __add_messages(session_ids, input_ids, adapter_names,
- input_embeddings, input_embedding_ranges):
+ input_multimodals):
add_msgs = []
sampling_param = SamplingParam(max_new_tokens=0)
batch_size = len(input_ids)
- if input_embeddings is None:
- input_embeddings = [None] * batch_size
- input_embedding_ranges = [None] * batch_size
- for (session_id, token_id, adapter_name, input_emb,
- input_ranges) in zip(session_ids, input_ids, adapter_names,
- input_embeddings,
- input_embedding_ranges):
+ if input_multimodals is None:
+ input_multimodals = [None] * batch_size
+ for (session_id, token_id, adapter_name,
+ in_mm) in zip(session_ids, input_ids, adapter_names,
+ input_multimodals):
if len(token_id) > self.max_input_len:
raise RuntimeError(
f'Expect input length<={self.max_input_len} '
f'but get {len(token_id)}')
- cur_input_embeddings: List[InputEmbeddings] = None
- if input_emb is not None and len(input_emb) > 0:
- assert len(input_emb) == len(input_ranges)
- cur_input_embeddings = [
- InputEmbeddings(emb, rg[0], rg[1])
- for emb, rg in zip(input_emb, input_ranges)
- ]
msg = dict(token_ids=token_id,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name,
- input_embeddings=cur_input_embeddings,
+ input_multimodals=in_mm,
return_logits=True)
add_msgs.append(msg)
req_types = [RequestType.ADD_MESSAGE] * batch_size
@@ -536,13 +377,6 @@ def __add_messages(session_ids, input_ids, adapter_names,
else:
adapter_names = [None] * batch_size
- if input_embeddings is not None:
- assert len(input_embeddings) == batch_size
- assert len(input_embedding_ranges) == batch_size
- else:
- input_embeddings = [None] * batch_size
- input_embedding_ranges = [None] * batch_size
-
session_ids = tuple(range(batch_size))
if sequence_start:
for sid in session_ids:
@@ -551,7 +385,7 @@ def __add_messages(session_ids, input_ids, adapter_names,
self._try_add_session(sid)
req_ids = __add_messages(session_ids, input_ids, adapter_names,
- input_embeddings, input_embedding_ranges)
+ multimodal)
req_idx_map = dict(zip(req_ids, range(len(req_ids))))
finish_count = batch_size
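
After this change the `EngineInstance` APIs take a single `multimodal` argument that is forwarded as `input_multimodals`; the old `input_embeddings`/`input_embedding_ranges` pair and the batched-infer helpers are gone. The hedged call-shape sketch below shows the new interface only; the payload layout inside each multimodal dict is model-specific and handled by the model's input processor, so the `'image'` entry is a placeholder, not a documented schema.

```python
from lmdeploy.messages import GenerationConfig


async def demo(instance, token_ids, image_data):
    """stream tokens for one session with a placeholder multimodal payload."""
    gen_config = GenerationConfig(max_new_tokens=128)
    multimodal = [{'image': image_data}]  # placeholder schema
    async for out in instance.async_stream_infer(session_id=0,
                                                 input_ids=token_ids,
                                                 gen_config=gen_config,
                                                 multimodal=multimodal):
        print(out.status, len(out.token_ids))
```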
diff --git a/lmdeploy/pytorch/engine/input_process.py b/lmdeploy/pytorch/engine/input_process.py
new file mode 100644
index 0000000000..7f442e153b
--- /dev/null
+++ b/lmdeploy/pytorch/engine/input_process.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
+
+TypeModelMetas = Dict[str, Any]
+
+InputMultiModalType = List[Dict[str, Any]]
+
+
+@dataclass
+class PreprocessInputResult:
+ """results of preprocess input."""
+ input_ids: List[int]
+ input_multimodals: Optional[MultiModalInputs] = None
+ model_metas: Optional[TypeModelMetas] = None
+
+
+class BaseModelInputProcessor(ABC):
+ """processor of model inputs."""
+
+ @abstractmethod
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_mms: InputMultiModalType = None,
+ **kwargs) -> PreprocessInputResult:
+ """preprocess input."""
+ raise NotImplementedError('Not implemented.')
+
+
+class DefaultModelInputProcessor(BaseModelInputProcessor):
+ """default model input processor."""
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_mms: MultiModalInputs = None,
+ **kwargs) -> PreprocessInputResult:
+ """preprocess input."""
+ return PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=input_mms,
+ )
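
Models that need custom multimodal handling are expected to return their own processor from `get_input_processor()` (see the model-agent changes below). The sketch that follows is a minimal subclass honoring only the `BaseModelInputProcessor` contract added above; the placeholder token id and the token-expansion rule are purely illustrative assumptions, not any model's real behavior.

```python
from typing import List

from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
                                                   PreprocessInputResult)

IMAGE_TOKEN_ID = 92544  # hypothetical placeholder id


class ToyImageInputProcessor(BaseModelInputProcessor):
    """expand each image placeholder to a fixed number of tokens."""

    def __init__(self, tokens_per_image: int = 256):
        self.tokens_per_image = tokens_per_image

    def preprocess_input(self,
                         input_ids: List[int],
                         input_mms=None,
                         **kwargs) -> PreprocessInputResult:
        new_ids = []
        for tok in input_ids:
            if tok == IMAGE_TOKEN_ID:
                new_ids.extend([IMAGE_TOKEN_ID] * self.tokens_per_image)
            else:
                new_ids.append(tok)
        return PreprocessInputResult(input_ids=new_ids,
                                     input_multimodals=input_mms)
```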
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 59d77f264a..5487639d29 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -135,21 +135,26 @@ def model_forward(
stream = stream or torch.cuda.current_stream()
with torch.cuda.stream(stream):
# forward
- inputs = inputs.to_device('cuda')
ctx_mgr = model.ctx_mgr
context = ctx_mgr.build_context(
inputs=inputs,
+ model_config=cache_engine.model_config,
world_size=world_size,
kv_caches=cache_engine.gpu_cache,
kv_quant_policy=cache_engine.cache_config.quant_policy,
)
with ctx_mgr.context(context):
+ model_metas = None
+ model_metas = model.update_model_metas(
+ past_key_values=cache_engine.gpu_cache,
+ context=context,
+ )
input_dict = model.prepare_inputs_for_generation(
past_key_values=cache_engine.gpu_cache,
context=context,
)
output = model(**input_dict)
- return dict(hidden_states=output)
+ return dict(hidden_states=output, model_metas=model_metas)
SwapMap = Dict[int, int]
@@ -177,6 +182,10 @@ def get_logits(self, hidden_states: torch.Tensor):
"""get logits of model output."""
raise NotImplementedError('Not implemented.')
+ def get_input_processor(self):
+ """get input processor."""
+ raise NotImplementedError('Not implemented.')
+
class BaseModelAgent(AutoModelAgent):
"""Base model agent.
@@ -267,14 +276,16 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
- while not self.stream.query():
- await asyncio.sleep(0)
return output
def get_logits(self, hidden_states: torch.Tensor):
"""get logits of model output."""
return self.patched_model.get_logits(hidden_states)
+ def get_input_processor(self):
+ """get input processor.."""
+ return self.patched_model.get_input_processor()
+
@torch.inference_mode()
def _tp_build_model(
@@ -360,14 +371,26 @@ def _broadcast_config(cache_config):
return patched_model, cache_engine, cache_config
-def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
+def _broadcast_inputs(rank: int, inputs: Any, group: dist.group,
+ stream: torch.cuda.Stream):
"""get input tensor parallel."""
# broadcast meta info
if rank != 0:
inputs = [None, None, None]
+ else:
+ device_inputs = inputs[0]
+ meta_inputs = device_inputs.to_device('meta')
+ inputs[0] = meta_inputs
with torch.cuda.stream(stream):
- dist.broadcast_object_list(inputs)
+ dist.broadcast_object_list(inputs, group=group)
+ if rank == 0:
+ device_inputs.broadcast()
+ else:
+ device_inputs = inputs[0].broadcast()
+
+ inputs[0] = device_inputs
+
return inputs
@@ -380,6 +403,7 @@ def _tp_model_loop(
adapters: Dict[str, str],
world_size: int,
barrier: mp.Barrier,
+ cpu_group: dist.group,
):
"""Start model loops for tensor parallel model inference.
@@ -405,11 +429,12 @@ def _tp_model_loop(
while True:
barrier.wait()
inputs, swap_in_map, swap_out_map = _broadcast_inputs(
- rank, None, stream)
+ rank, None, cpu_group, stream)
cache_swapping(cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
+ inputs = inputs.to_device('cuda')
model_forward(
patched_model,
@@ -441,10 +466,13 @@ def _start_tp_process(proc_id: int,
try:
from lmdeploy.pytorch.check_env import check_env_deeplink
check_env_deeplink(device_context.device_type)
+ timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
- timeout=timedelta(days=35600))
+ timeout=timeout)
+ cpu_group = dist.new_group(timeout=timeout, backend='gloo')
+ kwargs['cpu_group'] = cpu_group
dist_ctx = DistContext(rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
with get_dist_manager().context(dist_ctx), get_device_manager(
@@ -614,12 +642,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
rank = 0
try:
+ timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
- timeout=timedelta(days=35600))
+ timeout=timeout)
+ cpu_group = dist.new_group(timeout=timeout, backend='gloo')
dist_ctx = DistContext(rank=rank, world_size=world_size)
self._dist_ctx = dist_ctx
+ self._cpu_group = cpu_group
except Exception as e:
from traceback import print_exc
logger.error(f'Rank[{rank}] failed.')
@@ -661,7 +692,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
self.mp_bar.wait()
rank = 0
_broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
- self.stream)
+ self._cpu_group, self.stream)
+
cache_swapping(self.cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
@@ -687,14 +719,16 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
- while not self.stream.query():
- await asyncio.sleep(0)
return output
def get_logits(self, hidden_states: torch.Tensor):
"""get logits of model output."""
return self.patched_model.get_logits(hidden_states)
+ def get_input_processor(self):
+ """get input processor.."""
+ return self.patched_model.get_input_processor()
+
def _exit_handler(agent: TPModelAgent):
if hasattr(agent, 'patched_model'):
@@ -722,7 +756,7 @@ def build_model_agent(model_path: str,
custom_module_map (str): customized nn module map
"""
model_config = ModelConfig.from_pretrained(
- model_path, trust_remote_code=trust_remote_code, dtype=dtype)
+ model_path, trust_remote_code=trust_remote_code, dtype=dtype, tp=tp)
model_config.custom_module_map = custom_module_map
if tp == 1:
model_agent = BaseModelAgent(model_path,
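
The tensor-parallel path above now broadcasts a meta-tensor skeleton of `ModelInputs` over a dedicated gloo group and then ships the real tensor payloads with `dist.broadcast` (see `_broadcast_inputs` here together with `ModelInputs.broadcast()` later in this diff). Below is a hedged sketch of that two-step handshake, assuming an initialized process group and omitting the CUDA stream handling; `send_inputs` is an illustrative name, not an LMDeploy function.

```python
import torch.distributed as dist


def send_inputs(rank, inputs, swap_in_map, swap_out_map, cpu_group):
    """two-step broadcast: object skeleton on gloo, tensor data via dist.broadcast."""
    if rank == 0:
        obj_list = [inputs.to_device('meta'), swap_in_map, swap_out_map]
    else:
        obj_list = [None, None, None]
    dist.broadcast_object_list(obj_list, group=cpu_group)
    if rank == 0:
        inputs.broadcast()            # send the real tensor payloads
        return inputs, swap_in_map, swap_out_map
    inputs = obj_list[0].broadcast()  # receive payloads onto the device
    return inputs, obj_list[1], obj_list[2]
```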
diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py
index 34a11ae030..3d07225e43 100644
--- a/lmdeploy/pytorch/kernels/cuda/flashattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py
@@ -47,6 +47,17 @@ def softcapping(qk, logit_softcapping: tl.constexpr):
return qk
+@triton.jit
+def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr):
+ """load kv."""
+ if causal_mask:
+ return tl.load(ptrs,
+ boundary_check=boundary_check,
+ padding_option='zero')
+ else:
+ return tl.load(ptrs)
+
+
@triton.jit
def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
loop_start, loop_end, sm_scale, history_mask,
@@ -63,11 +74,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
for start_n in range(loop_start, loop_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
- k = tl.load(k_ptrs)
+ k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, ))
qk = tl.dot(q, k)
if BLOCK_DK1 != 0:
- k1 = tl.load(k1_ptrs)
+ k1 = _load_kv(k1_ptrs, causal_mask, boundary_check=(1, ))
qk += tl.dot(q1, k1)
if causal_mask:
@@ -117,7 +128,7 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
acc = acc * alpha[:, None]
# update acc
- v = tl.load(v_ptrs)
+ v = _load_kv(v_ptrs, causal_mask, boundary_check=(0, ))
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
@@ -172,6 +183,7 @@ def _flash_prefill_fwd_kernel(
kv_group_num,
head_dim_k,
head_dim_v,
+ causal: tl.constexpr,
window_size: tl.constexpr,
logit_softcapping: tl.constexpr,
BLOCK_M: tl.constexpr,
@@ -260,9 +272,13 @@ def _flash_prefill_fwd_kernel(
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
- history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ if causal:
+ history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N
+ else:
+ history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32)
+ loop_end = kv_seqlen // BLOCK_N * BLOCK_N
- loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N
acc, l_i, m_i = _prefill_fwd_inner(acc,
l_i,
m_i,
@@ -283,7 +299,10 @@ def _flash_prefill_fwd_kernel(
BLOCK_DK1=BLOCK_DK1)
loop_start = loop_end
- loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N)
+ if causal:
+ loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N)
+ else:
+ loop_end = kv_seqlen
acc, l_i, m_i = _prefill_fwd_inner(acc,
l_i,
m_i,
@@ -333,6 +352,7 @@ def flash_attention_fwd(
window_size: int = None,
sm_scale: float = None,
logit_softcapping: float = None,
+ causal: bool = True,
kv_layout: str = 'hsd',
):
"""varlen flash Attention forward.
@@ -383,6 +403,7 @@ def grid(args):
BLOCK_M = max(16, 8192 // BLOCK_DK)
else:
BLOCK_M = max(16, 16384 // BLOCK_DK)
+ BLOCK_M = min(128, BLOCK_M)
num_warps = 4
num_stages = min(4, max(2, 1024 // BLOCK_DK))
if BLOCK_DK >= 512:
@@ -416,6 +437,7 @@ def grid(args):
kv_group_num=kv_group_num,
head_dim_k=head_dim_k,
head_dim_v=head_dim_v,
+ causal=causal,
window_size=window_size,
logit_softcapping=logit_softcapping,
BLOCK_DK=BLOCK_DK,
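
The `causal` flag added above changes only the loop bounds of the two inner passes: the first pass covers the block-aligned prefix that needs no mask, the second pass handles the (possibly masked) tail. The pure-Python helper below reproduces that arithmetic with toy block sizes so the difference is easy to inspect; it is an illustration, not kernel code.

```python
BLOCK_M, BLOCK_N = 64, 64


def loop_bounds(history_len, start_m, kv_seqlen, causal):
    """return (first_loop_end, second_loop_end) as the kernel computes them."""
    if causal:
        first_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N
        second_end = min(kv_seqlen, first_end + BLOCK_M + BLOCK_N)
    else:
        first_end = kv_seqlen // BLOCK_N * BLOCK_N
        second_end = kv_seqlen
    return first_end, second_end


print(loop_bounds(history_len=100, start_m=1, kv_seqlen=300, causal=True))
print(loop_bounds(history_len=100, start_m=1, kv_seqlen=300, causal=False))
```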
diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
index 90b135743e..3a77164046 100644
--- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
@@ -31,7 +31,7 @@ def _flatten_kv_cache(
stride_vos: tl.constexpr,
stride_vod: tl.constexpr,
stride_boff,
- OUT_SIZE: tl.constexpr,
+ OUT_SIZE,
HEAD_DIM_K: tl.constexpr,
HEAD_DIM_V: tl.constexpr,
BLOCK_BS: tl.constexpr,
@@ -124,7 +124,7 @@ def _flatten_kv_cache_quant(
stride_vod: tl.constexpr,
stride_boff,
quant_policy: tl.constexpr,
- OUT_SIZE: tl.constexpr,
+ OUT_SIZE,
HEAD_DIM_K: tl.constexpr,
HEAD_DIM_V: tl.constexpr,
BLOCK_BS: tl.constexpr,
diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
index 8f86f0019a..fe82010761 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -3,6 +3,7 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
+from .flash_attention import flash_attention_fwd
from .fused_moe import fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
@@ -16,6 +17,7 @@
'fill_kv_cache',
'fused_moe',
'paged_attention_fwd',
+ 'flash_attention_fwd',
'linear',
'moe_gating_topk_softmax',
'multinomial_sampling',
diff --git a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
index fb2eee9d41..63564d7ed8 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
import dlinfer.ops as ext_ops
from torch import Tensor
@@ -9,7 +11,16 @@ def fill_kv_cache(
key_caches: Tensor,
value_caches: Tensor,
kv_start_indices: Tensor,
+ k_scales_zeros: Sequence[Optional[Tensor]],
+ v_scales_zeros: Sequence[Optional[Tensor]],
+ quant_bits: int = 0,
):
"""fill key/value state to cache for paged attention."""
- return ext_ops.fill_kv_cache(key_states, value_states, key_caches,
- value_caches, kv_start_indices)
+ return ext_ops.fill_kv_cache(key_states,
+ value_states,
+ key_caches,
+ value_caches,
+ kv_start_indices,
+ k_scales_zeros=k_scales_zeros,
+ v_scales_zeros=v_scales_zeros,
+ quant_bits=quant_bits)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
new file mode 100644
index 0000000000..1788f947ee
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+from dlinfer.utils.type_annotation import Tensor
+
+
+def flash_attention_fwd(
+ query_states: Tensor,
+ key_states: Tensor,
+ value_states: Tensor,
+ attn_output: Tensor,
+ q_start_loc: Tensor,
+ q_seqlens: Tensor,
+ kv_start_loc: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int = None,
+ window_size: int = None,
+ sm_scale: float = None,
+ logit_softcapping: float = None,
+ causal: bool = True,
+):
+ num_q_heads = query_states.shape[1]
+ num_kv_heads = value_states.shape[1]
+ return ext_ops.prefill_attention(
+ query_states,
+ key_states,
+ value_states,
+ q_start_loc,
+ q_seqlens,
+ max_q_seqlen,
+ num_q_heads,
+ num_kv_heads,
+ attn_mask=None,
+ softmax_scale=sm_scale,
+ attn_output=attn_output,
+ )
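
The new dlinfer wrapper above maps the engine's flash-attention interface onto `ext_ops.prefill_attention`. The sketch below only shows the call shape, under the assumption that dlinfer and its device backend are installed and that all tensors, including the preallocated `attn_output` (whose head dim may differ from the query's), already live on that device.

```python
from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd


def run_prefill(query, key, value, attn_output, q_start_loc, q_seqlens,
                kv_start_loc, kv_seqlens, max_q_seqlen):
    """attn_output is filled in place by the backend op."""
    return flash_attention_fwd(query, key, value, attn_output,
                               q_start_loc=q_start_loc,
                               q_seqlens=q_seqlens,
                               kv_start_loc=kv_start_loc,
                               kv_seqlens=kv_seqlens,
                               max_q_seqlen=max_q_seqlen,
                               causal=True)
```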
diff --git a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
index 72bab2d720..275ea65261 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
@@ -5,12 +5,13 @@
def fused_moe(
hidden_states: Tensor,
- top_k: int,
- topk_ids: Tensor,
- topk_weights: Tensor,
gate_up_weights: Tensor,
down_weights: Tensor,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
+ topk: int,
+ renormalize: bool,
):
- """ascend fused moe."""
- return ext_ops.fused_moe(hidden_states, top_k, topk_ids, topk_weights,
- gate_up_weights, down_weights)
+ """dlinfer fused moe."""
+ return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights,
+ topk_weights, topk_ids, topk, renormalize)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index 47bcb0cfff..ded85d476d 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -19,6 +19,9 @@ def prefill_attention(
block_size: int,
attn_mask: Sequence[Optional[Tensor]],
is_unpaged_prefill: Optional[bool],
+ kv_scales: Optional[Tensor],
+ kv_zeros: Optional[Tensor],
+ quant_bits: Optional[int],
) -> Tensor:
num_q_heads = query_states.shape[1]
num_kv_heads = value_states.shape[1]
@@ -53,11 +56,25 @@ def prefill_attention(
num_kv_heads,
attn_mask,
attn_output=attn_output,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
-def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
- max_kv_seq_len, block_offsets, block_size):
+def paged_token_attention(
+ q,
+ k_cache,
+ v_cache,
+ attn_output,
+ kv_seq_len,
+ max_kv_seq_len,
+ block_offsets,
+ block_size,
+ kv_scales: Optional[Tensor],
+ kv_zeros: Optional[Tensor],
+ quant_bits: Optional[int],
+):
num_q_heads, q_head_dim = q.shape[1:3]
num_kv_heads = k_cache.shape[-1] // q_head_dim
return ext_ops.paged_decode_attention(
@@ -71,6 +88,9 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
num_q_heads,
num_kv_heads,
attn_output=attn_output,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
@@ -91,6 +111,9 @@ def paged_attention_fwd(
block_size: int,
attn_mask: Sequence[Optional[Tensor]] = (),
is_unpaged_prefill: Optional[bool] = None,
+ kv_scales: Optional[Tensor] = None,
+ kv_zeros: Optional[Tensor] = None,
+ quant_bits: Optional[int] = 0,
):
if not is_decoding:
return prefill_attention(
@@ -108,6 +131,9 @@ def paged_attention_fwd(
block_size,
attn_mask,
is_unpaged_prefill,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
else:
return paged_token_attention(
@@ -119,4 +145,7 @@ def paged_attention_fwd(
max_kv_seq_len,
block_offsets,
block_size,
+ kv_scales=kv_scales,
+ kv_zeros=kv_zeros,
+ quant_bits=quant_bits,
)
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index b16a78f1f4..0aaba98c94 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -8,6 +8,7 @@
from torch import Tensor
from lmdeploy.messages import GenerationConfig, LogitsProcessor
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
from lmdeploy.utils import get_logger
from .block import LogicalTokenBlocks
@@ -205,10 +206,9 @@ def add_sequence(
sampling_param: SamplingParam = None,
adapter_name: str = None,
return_logits: bool = False,
- input_embeddings: List[InputEmbeddings] = None,
- mrope_position_ids: Tensor = None,
- mrope_position_delta: Tensor = None,
- cross_attention_states: Tensor = None) -> 'SchedulerSequence':
+ multimodals: MultiModalInputs = None,
+ input_embeddings: List[InputEmbeddings] = None
+ ) -> 'SchedulerSequence':
"""Add a new message."""
if isinstance(token_ids, Tensor):
token_ids = token_ids.numpy()
@@ -228,10 +228,8 @@ def add_sequence(
adapter_name=adapter_name,
arrive_time=time.time(),
history_embeddings=HistoryEmbeddings(input_embeddings),
+ history_multimodals=HistoryMultiModals(multimodals),
return_logits=return_logits,
- mrope_position_ids=mrope_position_ids,
- mrope_position_delta=mrope_position_delta,
- cross_attention_states=cross_attention_states,
)
self.sequences[seq.seq_id] = seq
if self.seq_manager is not None:
@@ -361,6 +359,66 @@ def copy(self):
return self.clone()
+class HistoryMultiModals:
+ """history of multimodal inputs of a sequence."""
+
+ def __init__(self, multimodals: MultiModalInputs):
+ if multimodals is None:
+ multimodals = dict()
+ self.multimodals = multimodals
+
+ def get_datas(self, start=0, end=-1):
+ """get multimodals from prompts position [start, end)."""
+ outs = dict()
+ test_range = range(start, end)
+ for modal_type, modal_datas in self.multimodals.items():
+ data = []
+ for modal_data in modal_datas:
+ if (modal_data.start not in test_range
+ and modal_data.end not in test_range):
+ continue
+ data.append(modal_data)
+ if len(data) > 0:
+ outs[modal_type] = data
+ return outs
+
+ def add_inputs(self, input_mms: MultiModalInputs):
+ """add new inputs."""
+ for modal_type, vals in input_mms.items():
+ if modal_type in self.multimodals:
+ self.multimodals[modal_type] += vals
+ else:
+ self.multimodals[modal_type] = vals
+
+ def empty(self):
+ """whether there is no multimodal input."""
+ if len(self.multimodals) == 0:
+ return True
+
+ return all(len(vals) == 0 for vals in self.multimodals.values())
+
+ @staticmethod
+ def update_multimodals(input_mms: MultiModalInputs, prev_len: int):
+ """update multimodals."""
+ for vals in input_mms.values():
+ for val in vals:
+ val.start += prev_len
+ val.end += prev_len
+ return input_mms
+
+ def get_encoder_len(self, start=0, end=-1):
+ """get lens of encoder."""
+ test_range = range(start, end)
+ out_len = 0
+ for _, modal_datas in self.multimodals.items():
+ for modal_data in modal_datas:
+ if modal_data.encoder_len is None:
+ continue
+ if (modal_data.start not in test_range
+ and modal_data.end not in test_range):
+ continue
+ out_len += modal_data.encoder_len
+ return out_len
+
+
@dataclass
class SchedulerSequence:
"""Scheduler message."""
@@ -369,6 +427,8 @@ class SchedulerSequence:
history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds)
history_embeddings: HistoryEmbeddings = field(
default_factory=HistoryEmbeddings)
+ history_multimodals: HistoryMultiModals = field(
+ default_factory=HistoryMultiModals)
num_new_tokens: int = 0
sampling_param: SamplingParam = field(default_factory=SamplingParam)
logical_blocks: LogicalTokenBlocks = field(
@@ -382,10 +442,7 @@ class SchedulerSequence:
random_offsets: int = 0
_status: MessageStatus = field(default=MessageStatus.WAITING, init=False)
num_ignored_history: int = 0
- mrope_position_ids: Optional[Tensor] = None
- mrope_position_delta: Optional[int] = None
- cross_attention_states: Optional[Tensor] = None
- history_cross_kv_seqlens: int = 0
+ model_meta: Dict[str, Any] = None
def __post_init__(self):
"""post init."""
@@ -394,6 +451,10 @@ def __post_init__(self):
self._num_images: int = len(self.history_embeddings)
self._num_token_ids: int = len(self.history_cache)
+ self._num_history_cross: int = 0
+ self._num_cross: int = self.history_multimodals.get_encoder_len(
+ 0, self._num_token_ids)
+
@property
def block_size(self) -> int:
"""block size."""
@@ -464,6 +525,16 @@ def num_all_ids(self):
"""num all tokens."""
return self.history_len + self._num_token_ids
+ @property
+ def num_cross(self):
+ """num cross."""
+ return self._num_cross
+
+ @property
+ def num_history_cross(self):
+ """num history cross."""
+ return self._num_history_cross
+
@property
def num_blocks(self):
"""num blocks."""
@@ -489,22 +560,22 @@ def num_all_tokens(self):
def num_all_cross_tokens(self):
"""num of all cross tokens."""
- if self.cross_attention_states is None:
- self.history_cross_kv_seqlens = 0
- else:
- self.history_cross_kv_seqlens = self.cross_attention_states.shape[
- -2]
- return self.history_cross_kv_seqlens
+ return self._num_cross + self._num_history_cross
+
+ def get_input_multimodals(self):
+ """get input multimodals."""
+ start = self.num_history_ids
+ end = self.num_all_ids
+ return self.history_multimodals.get_datas(start, end)
def update_token_ids(self,
token_ids: Tensor,
+ multimodals: MultiModalInputs = None,
embeddings: List[InputEmbeddings] = None,
- cross_attention_states: List[Tensor] = None):
+ model_meta: Dict[str, Any] = None):
"""Update token ids, old token ids will be added to history."""
- # cross attention
- if cross_attention_states is not None:
- self.history_cross_kv_seqlens += cross_attention_states.shape[-2]
- self.cross_attention_states = cross_attention_states
+ old_num_history_ids = self._num_history_ids
+
self._num_history_ids += self._num_token_ids
# update history image nums
self._num_history_images += self._num_images
@@ -516,6 +587,23 @@ def update_token_ids(self,
self._num_images = len(new_embeddings)
self.history_embeddings.append(new_embeddings)
+ # update multimodals
+ if multimodals is not None:
+ multimodals = HistoryMultiModals.update_multimodals(
+ multimodals, self.num_all_ids)
+ self.history_multimodals.add_inputs(multimodals)
+
+ # cross
+ self._num_history_cross += self._num_cross
+ if multimodals is not None:
+ self._num_cross = self.history_multimodals.get_encoder_len(
+ old_num_history_ids, self._num_history_ids)
+ else:
+ self._num_cross = 0
+
+ if model_meta is not None:
+ self.model_meta = model_meta
+
if isinstance(token_ids, Tensor):
token_ids = token_ids.numpy()
elif not isinstance(token_ids, np.ndarray):
@@ -539,3 +627,12 @@ def set_step(self, step: int):
self._num_history_ids = step
self._num_token_ids = num_all_ids - step
self.num_ignored_history = min(step, self.num_ignored_history)
+
+ self.model_meta = None
+
+ # cross
+ if self.history_multimodals is not None:
+ self._num_history_cross = self.history_multimodals.get_encoder_len(
+ 0, self.num_history_ids)
+ self._num_cross = self.history_multimodals.get_encoder_len(
+ self._num_history_ids, num_all_ids)
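
`HistoryMultiModals` replaces the per-field mrope/cross-attention bookkeeping with position-indexed multimodal records. The toy below (runnable once this patch is applied) uses a hypothetical stand-in dataclass exposing only the fields the class reads (`start`, `end`, `encoder_len`) to show how `get_datas` and `get_encoder_len` select items by token position; the real entries come from `lmdeploy.pytorch.multimodal.data_type`.

```python
from dataclasses import dataclass
from typing import Optional

from lmdeploy.pytorch.messages import HistoryMultiModals


@dataclass
class FakeModalData:  # stand-in, not the real multimodal tensor type
    start: int
    end: int
    encoder_len: Optional[int] = None


mm = HistoryMultiModals({'image': [FakeModalData(5, 21, encoder_len=16)]})
print(mm.get_datas(0, 10))        # starts inside [0, 10) -> returned
print(mm.get_datas(30, 40))       # neither endpoint in [30, 40) -> {}
print(mm.get_encoder_len(0, 32))  # 16
```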
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 669625d43d..d10da8557a 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -4,47 +4,19 @@
from typing import Any, Dict, List, Literal
import torch
+from torch import distributed as dist
from lmdeploy.pytorch.backends import get_backend
+from lmdeploy.pytorch.config import ModelConfig
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-@dataclass
-class MRopeModelInputs:
- """Multimodal rotary position inputs."""
- position_ids: List[torch.LongTensor] = None
- deltas: List[torch.LongTensor] = None
-
- def get_inputs(self, history_lengths: torch.Tensor,
- seq_lengths: torch.Tensor):
- mrope_position_ids = []
- for (his_len, seq_len, pos_ids,
- delta) in zip(history_lengths, seq_lengths, self.position_ids,
- self.deltas):
- assert pos_ids.dim() == 2, 'invalid mrope_position_ids'
- if his_len + seq_len <= pos_ids.shape[1]:
- mrope_position_ids.append(pos_ids[:,
- his_len:his_len + seq_len])
- else:
- mrope_position_ids.append(
- torch.tensor([his_len], device=delta.device).expand(3, -1)
- + delta)
-
- mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
- return mrope_position_ids
-
- def to_device(self, device: str):
- """to device."""
- out_dict = dict()
- for f in fields(self):
- k = f.name
- v = getattr(self, k)
- if isinstance(v, torch.Tensor):
- v = v.to(device)
- elif isinstance(v, list):
- v = [x.to(device) for x in v]
- out_dict[k] = v
-
- return MRopeModelInputs(**out_dict)
+def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
+ """broadcast tensor."""
+ if value.device.type == 'meta':
+ value = torch.empty_like(value, device=device)
+ dist.broadcast(value, src)
+ return value
@dataclass
@@ -56,6 +28,7 @@ class VisionModelInputs:
input_embeddings: List[List[torch.Tensor]] = None
input_embedding_ranges: List[torch.LongTensor] = None
input_embedding_indexing: torch.BoolTensor = None
+ input_multimodals: List[MultiModalTensor] = None
def to_device(self, device: str):
"""to device."""
@@ -63,12 +36,54 @@ def to_device(self, device: str):
for f in fields(self):
k = f.name
v = getattr(self, k)
+ if v is None:
+ continue
if isinstance(v, torch.Tensor):
v = v.to(device)
- elif k == 'input_embedding_ranges' and v is not None:
+ elif k == 'input_embedding_ranges':
v = [e.to(device) for e in v]
- elif k == 'input_embeddings' and v is not None:
+ elif k == 'input_embeddings':
v = [[e.to(device) for e in li] for li in v]
+ elif k == 'input_multimodals':
+ new_v = []
+ for mm_datas in v:
+ new_mm_datas = dict()
+ for modal_type, data in mm_datas.items():
+ data = [d.to_device(device) for d in data]
+ new_mm_datas[modal_type] = data
+ new_v.append(new_mm_datas)
+ v = new_v
+ out_dict[k] = v
+
+ return VisionModelInputs(**out_dict)
+
+ def broadcast(self):
+ """broadcast inputs.
+
+ Do `dist.broadcast_object_list(inputs.to_device('meta'))`
+ before broadcasting tensors.
+ """
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ v = getattr(self, k)
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = _broadcast_tensor(v)
+ elif k == 'input_embedding_ranges':
+ v = [_broadcast_tensor(e) for e in v]
+ elif k == 'input_embeddings':
+ v = [[_broadcast_tensor(e) for e in li] for li in v]
+ elif k == 'input_multimodals':
+ new_v = []
+ for mm_datas in v:
+ new_mm_datas = dict()
+ for modal_type, data in mm_datas.items():
+ data = [d.broadcast() for d in data]
+ new_mm_datas[modal_type] = data
+ new_v.append(new_mm_datas)
+ v = new_v
out_dict[k] = v
return VisionModelInputs(**out_dict)
@@ -119,9 +134,9 @@ class ModelInputs:
num_ignored_history: torch.LongTensor
local_adapter_ids: torch.LongTensor = None
vision_inputs: VisionModelInputs = None
- mrope_inputs: MRopeModelInputs = None
- cross_attention_states: torch.Tensor = None
- history_cross_kv_seqlens: torch.LongTensor = None
+ cross_length: torch.LongTensor = None
+ history_cross_length: torch.LongTensor = None
+ model_metas: List[Dict[str, Any]] = None
def update(self, input_ids: torch.LongTensor):
"""update input ids."""
@@ -132,44 +147,88 @@ def update(self, input_ids: torch.LongTensor):
self.input_ids = input_ids
return self
- def split(self, split_size: int, block_size: int):
+ def split(self, split_size: int):
"""split inputs."""
assert len(
self.seq_length) == 1, ('Can not perform split on batched input.')
- assert split_size % block_size == 0, (
- 'split_size should be multi of block_size.')
input_ids = self.input_ids
if input_ids.numel() < split_size:
return self
- num_blocks = split_size // block_size
- overlap = (self.history_lengths[0] % block_size != 0)
+ flatten_mms = []
+ vision_inputs = self.vision_inputs
+ if vision_inputs is not None:
+ if vision_inputs.input_multimodals is not None:
+ input_mms = vision_inputs.input_multimodals[0]
+
+ flatten_mms = []
+ for k, mms in input_mms.items():
+ mms = [(k, mm) for mm in mms]
+ flatten_mms += mms
+
+ flatten_mms = sorted(flatten_mms, key=lambda mm: mm[1].start)
+
max_seq_len = self.seq_length[0].item()
ret = []
- block_start = 0
- for i in range(0, max_seq_len, split_size):
- start = i
- end = min(max_seq_len, i + split_size)
- block_end = block_start + num_blocks
- if overlap:
- block_end += 1
-
- block_offsets = self.block_offsets
+ start = 0
+ history_cross_length = self.history_cross_length
+ cross_length = None
+ if history_cross_length is not None:
+ cross_length = self.history_cross_length.clone()
+ while start < max_seq_len:
+ vision_inputs = None
+ if len(flatten_mms) > 0:
+ mm_start = flatten_mms[0][1].start
+ mm_end = flatten_mms[0][1].end
+ if mm_start > self.history_lengths + start:
+ end = min(mm_start - self.history_lengths,
+ start + split_size)
+ else:
+ input_mms = dict()
+ key, mm = flatten_mms.pop(0)
+ input_mms.setdefault(key, [])
+ input_mms[key].append(mm)
+ end = start + mm.end - mm.start
+ while len(flatten_mms) > 0:
+ next_mm = flatten_mms[0]
+ next_start = next_mm[1].start
+ next_end = next_mm[1].end
+ if next_start < mm_end:
+ key = next_mm[0]
+ input_mms.setdefault(key, [])
+ input_mms[key].append(next_mm[1])
+ end += max(0, next_end - mm_end)
+ flatten_mms.pop(0)
+
+ if cross_length is not None:
+ encoder_len = next_mm[1].encoder_len
+ if encoder_len is not None:
+ cross_length += encoder_len
+ else:
+ break
+ vision_inputs = VisionModelInputs(
+ input_multimodals=[input_mms], )
+ else:
+ end = min(max_seq_len, start + split_size)
+
inp = ModelInputs(
input_ids=self.input_ids[:, start:end],
seq_length=input_ids.new_tensor([end - start]),
- block_offsets=block_offsets,
+ block_offsets=self.block_offsets,
history_lengths=self.history_lengths + start,
is_decoding=self.is_decoding,
num_ignored_history=self.num_ignored_history,
local_adapter_ids=self.local_adapter_ids,
- vision_inputs=self.vision_inputs,
- mrope_inputs=self.mrope_inputs,
- cross_attention_states=self.cross_attention_states,
+ vision_inputs=vision_inputs,
+ model_metas=self.model_metas,
+ cross_length=cross_length,
+ history_cross_length=history_cross_length,
)
ret.append(inp)
- block_start += num_blocks
+ history_cross_length = cross_length
+
+ start = end
return ret
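
Editor's note: a minimal standalone sketch (not part of the patch) of the chunking rule the new `split` implements: plain text is cut every `split_size` tokens, a chunk stops early when the next image span begins, and an image span is always emitted as a single chunk so its vision tokens are never split. History offsets and overlapping images are ignored here for brevity.

def plan_chunks(seq_len, split_size, image_spans):
    """Toy helper mirroring the chunk boundaries chosen by ModelInputs.split."""
    chunks, start = [], 0
    spans = sorted(image_spans)
    while start < seq_len:
        if spans and spans[0][0] <= start:
            s, e = spans.pop(0)
            end = start + (e - s)                       # whole image in one chunk
        elif spans:
            end = min(spans[0][0], start + split_size)  # stop at the image start
        else:
            end = min(seq_len, start + split_size)
        chunks.append((start, end))
        start = end
    return chunks

# plan_chunks(842, 256, [(10, 586)]) -> [(0, 10), (10, 586), (586, 842)]
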
@@ -183,8 +242,24 @@ def to_device(self, device: str):
v = v.to(device)
elif isinstance(v, VisionModelInputs):
v = v.to_device(device)
- elif isinstance(v, MRopeModelInputs):
- v = v.to_device(device)
+ out_dict[k] = v
+
+ return ModelInputs(**out_dict)
+
+ def broadcast(self):
+ """broadcast inputs.
+
+ Call `dist.broadcast_object_list(inputs.to_device('meta'))`
+ before broadcasting the tensors.
+ """
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ v = getattr(self, k)
+ if isinstance(v, torch.Tensor):
+ v = _broadcast_tensor(v)
+ elif isinstance(v, VisionModelInputs):
+ v = v.broadcast()
out_dict[k] = v
return ModelInputs(**out_dict)
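
Editor's note: the `broadcast` helpers assume a two-phase tensor-parallel handoff: the dataclass skeleton is pickled and broadcast with meta tensors first, then each tensor payload is broadcast separately. A rough driver sketch, assuming an initialized default process group and that rank 0 already holds the real inputs on the GPU (`sync_model_inputs` is an illustrative name, not part of the patch):

import torch.distributed as dist

def sync_model_inputs(inputs, rank: int):
    """Two-phase broadcast of ModelInputs from rank 0 to all ranks."""
    # phase 1: ship the dataclass skeleton; tensors travel as cheap meta tensors
    objs = [inputs.to_device('meta')] if rank == 0 else [None]
    dist.broadcast_object_list(objs, src=0)
    if rank != 0:
        inputs = objs[0]
    # phase 2: broadcast every tensor payload; non-source ranks allocate real
    # buffers from the meta shapes inside _broadcast_tensor
    return inputs.broadcast()
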
@@ -198,6 +273,7 @@ class StepContext:
dataclass provide these infos and tools.
"""
input_ids: torch.LongTensor
+ model_config: ModelConfig
block_offsets: torch.LongTensor
position_ids: torch.LongTensor
attention_mask: torch.LongTensor
@@ -210,13 +286,14 @@ class StepContext:
local_adapter_ids: torch.LongTensor = None
input_embeddings: torch.Tensor = None
input_embedding_indexing: torch.Tensor = None
+ input_multimodals: List[MultiModalTensor] = None
vision_inputs: VisionModelInputs = None
- mrope_position_ids: torch.Tensor = None
attn_metadata: Any = None
- cross_attn_metadata: Any = None
- cross_attention_states: torch.Tensor = None
+ cross_seqlens: torch.LongTensor = None
cross_kv_seqlens: torch.LongTensor = None
+ cross_attn_metadata: Any = None
kv_quant_policy: Literal[0, 4, 8] = 0
+ model_metas: List[Dict[str, Any]] = None
_outputs: Dict = field(default_factory=dict)
@@ -224,6 +301,7 @@ class StepContext:
def new(
cls,
inputs: ModelInputs,
+ model_config: ModelConfig,
world_size: int = 1,
kv_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
@@ -239,24 +317,21 @@ def new(
history_seqlens = inputs.history_lengths
device = q_seqlens.device
+ input_multimodals = None
+ if inputs.vision_inputs is not None:
+ input_multimodals = inputs.vision_inputs.input_multimodals
+
# for vlm
input_embeddings, input_embedding_indexing = None, None
if (inputs.vision_inputs is not None
and inputs.vision_inputs.input_embeddings is not None):
input_embeddings, input_embedding_indexing = \
inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens)
- # for mrope
- mrope_position_ids = None
- if inputs.mrope_inputs is not None:
- mrope_position_ids = inputs.mrope_inputs.get_inputs(
- history_seqlens, q_seqlens)
# kv_seqlens
- cross_attention_states = inputs.cross_attention_states
if inputs.is_decoding:
attention_mask = torch.ones_like(q_seqlens)[:, None]
- position_ids = history_seqlens.unsqueeze(-1)
- cross_attention_states = None
+ position_ids = history_seqlens.unsqueeze(-1).clone()
else:
max_q_seqlen = q_seqlens.max().item()
mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
@@ -265,6 +340,13 @@ def new(
position_ids += history_seqlens.unsqueeze(-1)
q_start_loc = q_seqlens.cumsum(0) - q_seqlens
+ # cross
+ cross_seqlens = inputs.cross_length
+ cross_kv_seqlens = None
+ if inputs.cross_length is not None:
+ cross_kv_seqlens = (inputs.cross_length +
+ inputs.history_cross_length)
+
# position ids 1d
position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None]
# seq_len + history_length
@@ -273,10 +355,12 @@ def new(
ret = StepContext(
input_ids=inputs.input_ids,
+ model_config=model_config,
block_offsets=inputs.block_offsets,
position_ids=position_ids,
input_embeddings=input_embeddings,
input_embedding_indexing=input_embedding_indexing,
+ input_multimodals=input_multimodals,
attention_mask=attention_mask,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
@@ -286,10 +370,10 @@ def new(
world_size=world_size,
local_adapter_ids=inputs.local_adapter_ids,
vision_inputs=inputs.vision_inputs,
- mrope_position_ids=mrope_position_ids,
- cross_attention_states=cross_attention_states,
- cross_kv_seqlens=inputs.history_cross_kv_seqlens,
kv_quant_policy=kv_quant_policy,
+ model_metas=inputs.model_metas,
+ cross_seqlens=cross_seqlens,
+ cross_kv_seqlens=cross_kv_seqlens,
)
ret = get_backend().update_step_context(ret)
@@ -318,6 +402,7 @@ def __init__(self):
@staticmethod
def build_context(
inputs: ModelInputs,
+ model_config: ModelConfig,
world_size: int = 1,
kv_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
@@ -325,6 +410,7 @@ def build_context(
"""build context."""
return StepContext.new(
inputs,
+ model_config,
world_size,
kv_caches,
kv_quant_policy,
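
Editor's note: call sites of `build_context` now have to thread the `ModelConfig` through as well; an illustrative call follows (the surrounding objects such as `inputs`, `model_config`, `tp_world_size` and `kv_caches` are assumed to exist in the caller):

ctx = StepContextManager.build_context(
    inputs,          # ModelInputs for this step
    model_config,    # new positional argument forwarded to StepContext.new
    world_size=tp_world_size,
    kv_caches=kv_caches,
    kv_quant_policy=0,
)
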
diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py
index 8d7a21a0a6..73f64d277c 100644
--- a/lmdeploy/pytorch/models/chatglm2.py
+++ b/lmdeploy/pytorch/models/chatglm2.py
@@ -1,101 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
+from torch.nn import functional as F
from transformers.configuration_utils import PretrainedConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1
-def get_vision_expert_mask(token_type_ids: torch.LongTensor):
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1]
- == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:]
- == VISION_TOKEN_TYPE)
- language_token_mask = ~vision_token_mask
- return vision_token_mask, language_token_mask
-
-
-def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor:
- tmp = x.clone()
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
- tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
- tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
- # final position ids
- y = torch.zeros_like(x, dtype=torch.long)
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
- (tmp[:, 1:] == VISION_TOKEN_TYPE) &
- (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
- y = y.cumsum(dim=-1)
- return y
-
-
-def _get_cogvlm_position_ids(context):
- """get cogvlm position_ids."""
- q_seqlens = context.q_seqlens
- history_lengths = context.kv_seqlens - q_seqlens
- vision_input_info = context.vision_inputs
- position_id_offsets = (vision_input_info.history_image_token_lengths -
- vision_input_info.history_image_nums * 3)
- lang_ids = None
- vis_ids = None
- if context.is_decoding:
- position_ids = history_lengths - position_id_offsets
- else:
- if vision_input_info.input_embeddings is not None and len(
- vision_input_info.input_embeddings) > 0:
- starts = history_lengths - vision_input_info.history_lengths
- ends = starts + q_seqlens
- token_type_ids = vision_input_info.input_embedding_indexing.to(
- torch.int)
- history_position_lengths = (vision_input_info.history_lengths -
- position_id_offsets)
- position_ids_all = (history_position_lengths[:, None] +
- build_position_ids(token_type_ids))
- position_ids = torch.cat([
- pids[s:e]
- for (pids, s, e) in zip(position_ids_all, starts, ends)
- ])
- vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids)
- vision_token_mask = torch.cat([
- masks[s:e]
- for (masks, s, e) in zip(vision_token_mask_all, starts, ends)
- ])
- mask_indexing = torch.arange(vision_token_mask.shape[-1],
- device=vision_token_mask.device)
- vis_ids = mask_indexing[vision_token_mask]
- lang_ids = mask_indexing[~vision_token_mask]
-
- else:
- position_ids = context.attention_mask.long().cumsum(-1) - 1
- position_ids += (history_lengths -
- position_id_offsets).unsqueeze(-1)
- device = position_ids.device
- position_ids_1d = [
- ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu())
- ]
- position_ids = torch.cat(position_ids_1d).to(device)
-
- return position_ids, lang_ids, vis_ids
-
-
class SelfAttention(torch.nn.Module):
"""Parallel self-attention layer abstract class.
@@ -112,11 +40,10 @@ def __init__(self,
self.projection_size = config.kv_channels * config.num_attention_heads
self.num_attention_heads = config.num_attention_heads
- self.num_kv_heads = self.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
self.head_size = (self.projection_size // config.num_attention_heads)
- self.multi_query_attention = config.multi_query_attention
- if self.multi_query_attention:
- self.num_kv_heads = config.multi_query_group_num
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
self.query_key_value = build_qkv_proj(
config.hidden_size,
num_q_heads=self.num_attention_heads,
@@ -126,7 +53,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# apply rotary
self.apply_rotary_pos_emb = ApplyRotaryEmb()
@@ -410,6 +337,286 @@ def forward(self, input_ids):
return embeddings
+class PatchEmbedding(nn.Module):
+ """vision embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.proj = nn.Conv2d(config.in_channels,
+ config.hidden_size,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ dtype=dtype,
+ device=device)
+ self.cls_embedding = nn.Parameter(
+ torch.empty(1, config.hidden_size, dtype=dtype, device=device))
+ self.position_embedding = nn.Embedding(config.num_positions,
+ config.hidden_size,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, images):
+ """forward."""
+ x = self.proj(images)
+ x = x.flatten(2).transpose(1, 2)
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_token, x), dim=1)
+ x += self.position_embedding.weight.unsqueeze(0)
+ return x
+
+
+class EVA2CLIPAttention(nn.Module):
+ """vision attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ hidden_size = config.hidden_size
+ num_heads = config.num_heads
+ head_dim = config.hidden_size // config.num_heads
+ self.scale = head_dim**-0.5
+
+ # packed qkv
+ self.query_key_value = build_qkv_proj(
+ hidden_size,
+ num_q_heads=num_heads,
+ num_kv_heads=num_heads,
+ head_size=head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # o_proj
+ self.dense = build_rowwise_linear(hidden_size,
+ hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ # qkv proj
+ qkv_states = self.query_key_value(hidden_states)
+ q, k, v = self.query_key_value.split_qkv(qkv_states)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.dense(attn_output)
+ return attn_output
+
+
+class EVA2CLIPMLP(nn.Module):
+ """vision MLP."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+
+ # fc1 (up projection)
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # activation
+ if config.hidden_act in [
+ 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python'
+ ]:
+ self.activation_fn = nn.GELU()
+ else:
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ # down
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ x = self.fc1(x)
+ x = self.activation_fn(x)
+ x = self.fc2(x)
+ return x
+
+
+class EVA2CLIPTransformerLayer(nn.Module):
+ """vision trans layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.input_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)
+ self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, hidden_states):
+ """forward."""
+ attention_input = hidden_states
+ attention_output = self.input_layernorm(
+ self.attention(attention_input))
+ hidden_states = attention_input + attention_output
+ mlp_input = hidden_states
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
+ output = mlp_input + mlp_output
+ return output
+
+
+class EVA2CLIPTransformer(nn.Module):
+ """vision transformer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ EVA2CLIPTransformerLayer(config, dtype=dtype, device=device)
+ for _ in range(config.num_hidden_layers)
+ ])
+
+ def forward(self, hidden_states):
+ """forward."""
+ for layer_module in self.layers:
+ hidden_states = layer_module(hidden_states)
+ return hidden_states
+
+
+class GLU(nn.Module):
+ """GLU."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ in_features: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.linear_proj = nn.Linear(in_features,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.norm1 = nn.LayerNorm(config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.act1 = nn.GELU()
+ self.act2 = nn.functional.silu
+ self.dense_h_to_4h = nn.Linear(config.hidden_size,
+ config.ffn_hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.gate_proj = nn.Linear(config.hidden_size,
+ config.ffn_hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, x):
+ x = self.linear_proj(x)
+ x = self.act1(self.norm1(x))
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
+ x = self.dense_4h_to_h(x)
+ return x
+
+
+class EVA2CLIPModel(nn.Module):
+ """vision model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from argparse import Namespace
+ vision_config = Namespace(**config.vision_config)
+
+ self.patch_embedding = PatchEmbedding(vision_config,
+ dtype=dtype,
+ device=device)
+ self.transformer = EVA2CLIPTransformer(vision_config,
+ dtype=dtype,
+ device=device)
+ self.linear_proj = GLU(config,
+ in_features=config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
+ out_channels=config.hidden_size,
+ kernel_size=2,
+ stride=2,
+ dtype=dtype,
+ device=device)
+ self.boi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+ self.eoi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+ self.scaling_factor = vision_config.scaling_factor
+
+ def forward(self, images):
+ """forward."""
+ x = self.patch_embedding(images)
+ x = self.transformer(x)
+
+ x = x[:, 1:]
+
+ b, s, h = x.shape
+ grid_size = int(s**0.5)
+ x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
+ x = self.conv(x)
+
+ x = x.flatten(2).transpose(1, 2)
+ x = self.linear_proj(x)
+ boi = self.boi.expand(x.shape[0], -1, -1)
+ eoi = self.eoi.expand(x.shape[0], -1, -1)
+ x = torch.cat((boi, x, eoi), dim=1)
+ x = x / self.scaling_factor
+ return x
+
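
Editor's note: the vision tower emits one token per 2x2 block of ViT patches (the stride-2 conv) plus the boi/eoi embeddings, which is the same count used as `vision_token_num` further down. Illustrative arithmetic with assumed GLM-4V-style sizes:

image_size, patch_size = 1120, 14        # assumed values, read from vision_config
grid = image_size // patch_size          # 80 patches per side from the ViT
downsampled = grid // 2                  # 40 after the stride-2 conv
vision_token_num = downsampled ** 2 + 2  # 1602 tokens handed to the LLM
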
+
class ChatGLMModel(nn.Module):
def __init__(self,
@@ -442,19 +649,32 @@ def __init__(self,
dtype=dtype,
device=device)
+ self.vision = None
+ if hasattr(config, 'vision_config'):
+ self.vision = EVA2CLIPModel(config, dtype=dtype, device=device)
+
def forward(
self,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
attn_metadata: Any = None,
+ images: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
):
"""forward."""
# token embedding
if inputs_embeds is None:
+ images_features = None
+ if images is not None:
+ images_features = self.vision(images)
+ images_features = images_features.flatten(0, 1)[None]
inputs_embeds = self.embedding(input_ids)
+ if images is not None:
+ inputs_embeds.masked_scatter_(image_mask[..., None],
+ images_features)
hidden_states = inputs_embeds
@@ -477,7 +697,8 @@ def get_input_embeddings(self):
return self.embedding
-class ChatGLMForConditionalGeneration(nn.Module, CudaGraphMixin):
+class ChatGLMForConditionalGeneration(nn.Module, DeployModelMixin,
+ CudaGraphMixin):
"""rewrote model of LlamaForCausalLM."""
def __init__(self,
@@ -491,12 +712,16 @@ def __init__(self,
# build Model
self.transformer = ChatGLMModel(config, dtype=dtype, device=device)
+ self.input_processor = ChatGLMInputProcessor(self.config, dtype)
+
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ images: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
**kwargs,
):
@@ -506,6 +731,8 @@ def forward(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
return hidden_states
@@ -529,8 +756,23 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
- if context.vision_inputs is not None:
- position_ids = _get_cogvlm_position_ids(context)[0][None]
+
+ images = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ images = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ images = [data for im_data in images for data in im_data]
+ if len(images) != 0:
+ image_token_id = images[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ images = torch.stack([data.data for data in images])
+ else:
+ images = None
+ image_mask = None
# process vision embeddings
vision_embeddings = context.input_embeddings
@@ -548,9 +790,92 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ model_metas = context.model_metas
+ if not hasattr(self.config, 'vision_config'):
+ return model_metas
+
+ input_multimodals = context.input_multimodals
+ if input_multimodals is None:
+ input_imgs = [[] for _ in model_metas]
+ else:
+ input_imgs = []
+ for mm in input_multimodals:
+ if mm is None:
+ input_imgs.append([])
+ else:
+ input_imgs.append(mm.get('image', []))
+
+ config = self.config
+ image_size: int = config.vision_config['image_size']
+ patch_size: int = config.vision_config['patch_size']
+ vision_token_num = ((image_size // patch_size // 2) *
+ (image_size // patch_size // 2) + 2)
+ num_pad = vision_token_num - 3
+
+ batched_num_img_tokens = []
+ new_model_metas = []
+ for meta, imgs in zip(model_metas, input_imgs):
+ if meta is None:
+ num_img_tokens = 0
+ else:
+ num_img_tokens = meta.get('num_img_tokens', 0)
+
+ batched_num_img_tokens.append(num_img_tokens)
+
+ num_img_tokens += num_pad * len(imgs)
+ new_model_metas.append(dict(num_img_tokens=num_img_tokens))
+
+ # prepare cogvlm position_ids
+ q_seqlens = context.q_seqlens
+ position_ids = context.position_ids
+
+ if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs):
+ num_img_tokens = torch.tensor(batched_num_img_tokens,
+ device=position_ids.device)
+ position_ids -= num_img_tokens[None]
+ else:
+ batched_position_ids = position_ids[0].split(q_seqlens)
+ for pos_ids, num_img_tok, imgs in zip(batched_position_ids,
+ batched_num_img_tokens,
+ input_imgs):
+ pos_ids -= num_img_tok
+ if len(imgs) == 0:
+ continue
+
+ seq_len = pos_ids.size(0)
+ start = pos_ids[0].cpu().item()
+ new_pos_ids = []
+
+ imgs = sorted(imgs, key=lambda img: img.start)
+ for img in imgs:
+ img_pad_pos = img.start + 1 - num_img_tok
+ num_pad = img.end - img.start - 2
+ new_pos_ids += list(range(start, img_pad_pos))
+ new_pos_ids += [img_pad_pos] * num_pad
+ start = img_pad_pos + 1
+ num_img_tok += num_pad
+
+ remain = seq_len - len(new_pos_ids)
+ new_pos_ids += list(range(start, start + remain))
+
+ new_pos_ids = pos_ids.new_tensor(new_pos_ids)
+ pos_ids[:] = new_pos_ids
+
+ position_ids = torch.cat(batched_position_ids)[None]
+ context.position_ids = position_ids
+
+ return new_model_metas
+
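
Editor's note: the rewrite above follows the CogVLM position convention: an image consumes only three position ids no matter how many vision tokens it expands to, one for `<boi>`, one shared by every image feature token, and one for `<eoi>`. A toy single-image layout under that convention (indices made up):

#   tokens:    t0  t1  <boi>  v   v   v   v  <eoi>  t2
#   index:      0   1    2    3   4   5   6    7     8
#   position:   0   1    2    3   3   3   3    4     5
span = 8 - 2                       # <boi> .. <eoi> covers 6 tokens here
positions_used = 3                 # boi, the shared feature position, eoi
num_pad = span - positions_used    # 3, the per-image offset tracked in model_metas
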
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
# modify from vllm
@@ -558,7 +883,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if 'transformer.vision' in name:
+ if '.query_key_value' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
continue
+
if 'rotary_pos_emb.inv_freq' in name:
continue
if ('rotary_pos_emb.cos_cached' in name
@@ -581,3 +916,53 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class ChatGLMInputProcessor(BaseModelInputProcessor):
+ """input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ if hasattr(config, 'vision_config'):
+ vision_config = config.vision_config
+ self.image_size = vision_config['image_size']
+ self.patch_size = vision_config['patch_size']
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+ self.vision_token_num = self.num_patches // 4
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ num_pad = input_mm['image_tokens']
+ image_token_id = input_mm.get('image_token_id', 0)
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
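
Editor's note: a hypothetical call into the processor for orientation; the values are invented, but the dict keys (`pixel_values`, `offset`, `image_tokens`, `image_token_id`) are the ones `preprocess_input` actually reads:

import torch

# `cfg` stands for the checkpoint's PretrainedConfig; it must expose
# `vision_config` for the image branch to be set up.
proc = ChatGLMInputProcessor(cfg, torch.float16)
result = proc.preprocess_input(
    input_ids=list(range(1700)),                 # dummy prompt ids
    input_multimodals=[dict(
        pixel_values=torch.rand(3, 1120, 1120),  # dummy preprocessed image
        offset=5,                                # image placeholder start
        image_tokens=1602,                       # boi + features + eoi
        image_token_id=0,                        # model-specific placeholder id
    )])
img = result.input_multimodals['image'][0]
assert (img.start, img.end) == (5, 5 + 1602)
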
diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py
index 6caf10df00..c460b8e44f 100644
--- a/lmdeploy/pytorch/models/cogvlm.py
+++ b/lmdeploy/pytorch/models/cogvlm.py
@@ -1,20 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import Namespace
from typing import Any, Iterable, List, Optional, Tuple
import torch
import torch.distributed as dist
+import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from lmdeploy.pytorch.distributed import get_world_rank
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
class VisionExpertAttention(nn.Module):
@@ -28,8 +35,9 @@ def __init__(self,
is_cogvlm2 = hasattr(config, 'num_multi_query_heads')
quantization_config = getattr(config, 'quantization_config', None)
num_heads = config.num_attention_heads
- num_key_value_heads = getattr(config, 'num_multi_query_heads',
- num_heads)
+ num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads)
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
self.hidden_size = hidden_size
@@ -46,7 +54,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
self.language_expert_query_key_value = build_qkv_proj(
hidden_size,
num_q_heads=num_heads,
@@ -56,7 +64,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
@@ -322,6 +330,283 @@ def forward(
return outputs
+class PatchEmbedding(nn.Module):
+ """vision embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.proj = nn.Conv2d(config.in_channels,
+ config.hidden_size,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ dtype=dtype,
+ device=device)
+ self.cls_embedding = nn.Parameter(
+ torch.empty(1, config.hidden_size, dtype=dtype, device=device))
+ self.position_embedding = nn.Embedding(config.num_positions,
+ config.hidden_size,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, images):
+ """forward."""
+ x = self.proj(images)
+ x = x.flatten(2).transpose(1, 2)
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_token, x), dim=1)
+ x += self.position_embedding.weight.unsqueeze(0)
+ return x
+
+
+class EVA2CLIPAttention(nn.Module):
+ """vision attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ hidden_size = config.hidden_size
+ num_heads = config.num_heads
+ head_dim = config.hidden_size // config.num_heads
+ self.scale = head_dim**-0.5
+
+ # packed qkv
+ self.query_key_value = build_qkv_proj(
+ hidden_size,
+ num_q_heads=num_heads,
+ num_kv_heads=num_heads,
+ head_size=head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # o_proj
+ self.dense = build_rowwise_linear(hidden_size,
+ hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ # qkv proj
+ qkv_states = self.query_key_value(hidden_states)
+ q, k, v = self.query_key_value.split_qkv(qkv_states)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.dense(attn_output)
+ return attn_output
+
+
+class EVA2CLIPMLP(nn.Module):
+ """vision MLP."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+
+ # fc1 (up projection)
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # activation
+ if config.hidden_act in [
+ 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python'
+ ]:
+ self.activation_fn = nn.GELU()
+ else:
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ # down
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ x = self.fc1(x)
+ x = self.activation_fn(x)
+ x = self.fc2(x)
+ return x
+
+
+class EVA2CLIPTransformerLayer(nn.Module):
+ """vision trans layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.input_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)
+ self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, hidden_states):
+ """forward."""
+ attention_input = hidden_states
+ attention_output = self.input_layernorm(
+ self.attention(attention_input))
+ hidden_states = attention_input + attention_output
+ mlp_input = hidden_states
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
+ output = mlp_input + mlp_output
+ return output
+
+
+class EVA2CLIPTransformer(nn.Module):
+ """vision transformer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ EVA2CLIPTransformerLayer(config, dtype=dtype, device=device)
+ for _ in range(config.num_hidden_layers)
+ ])
+
+ def forward(self, hidden_states):
+ """forward."""
+ for layer_module in self.layers:
+ hidden_states = layer_module(hidden_states)
+ return hidden_states
+
+
+class GLU(nn.Module):
+ """GLU."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ in_features: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.linear_proj = nn.Linear(in_features,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.norm1 = nn.LayerNorm(config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.act1 = nn.GELU()
+ self.act2 = nn.functional.silu
+ self.dense_h_to_4h = nn.Linear(config.hidden_size,
+ config.intermediate_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.gate_proj = nn.Linear(config.hidden_size,
+ config.intermediate_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size,
+ config.hidden_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, x):
+ x = self.linear_proj(x)
+ x = self.act1(self.norm1(x))
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
+ x = self.dense_4h_to_h(x)
+ return x
+
+
+class EVA2CLIPModel(nn.Module):
+ """vision model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ vision_config = Namespace(**config.vision_config)
+
+ self.patch_embedding = PatchEmbedding(vision_config,
+ dtype=dtype,
+ device=device)
+ self.transformer = EVA2CLIPTransformer(vision_config,
+ dtype=dtype,
+ device=device)
+ self.linear_proj = GLU(config,
+ in_features=vision_config.hidden_size,
+ dtype=dtype,
+ device=device)
+ self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
+ out_channels=vision_config.hidden_size,
+ kernel_size=2,
+ stride=2,
+ dtype=dtype,
+ device=device)
+ self.boi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+ self.eoi = nn.Parameter(
+ torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
+
+ def forward(self, images):
+ """forward."""
+ x = self.patch_embedding(images)
+ x = self.transformer(x)
+
+ x = x[:, 1:]
+
+ b, s, h = x.shape
+ grid_size = int(s**0.5)
+ x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
+ x = self.conv(x)
+
+ x = x.flatten(2).transpose(1, 2)
+ x = self.linear_proj(x)
+ boi = self.boi.expand(x.shape[0], -1, -1)
+ eoi = self.eoi.expand(x.shape[0], -1, -1)
+ x = torch.cat((boi, x, eoi), dim=1)
+ return x
+
+
class CogVLMModel(nn.Module):
"""model."""
@@ -353,6 +638,9 @@ def __init__(self,
dtype=dtype,
device=device)
+ # vision model
+ self.vision = EVA2CLIPModel(config, dtype=dtype, device=device)
+
# build rotary embedding
emb_type = RopeType.LinearScaling
rope_dim = config.hidden_size // config.num_attention_heads
@@ -371,6 +659,7 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
attn_metadata: Any = None,
+ images: torch.Tensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
lang_ids: torch.LongTensor = None,
vision_ids: torch.LongTensor = None,
@@ -379,7 +668,12 @@ def forward(
# token embedding
if inputs_embeds is None:
+ if images is not None:
+ images_features = self.vision(images)
+
inputs_embeds = self.embed_tokens(input_ids)
+ if vision_ids is not None:
+ inputs_embeds[0, vision_ids] = images_features.flatten(0, 1)
hidden_states = inputs_embeds
@@ -416,85 +710,7 @@ def get_input_embeddings(self):
VISION_TOKEN_TYPE = 1
-def get_vision_expert_mask(token_type_ids: torch.LongTensor):
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1]
- == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:]
- == VISION_TOKEN_TYPE)
- language_token_mask = ~vision_token_mask
- return vision_token_mask, language_token_mask
-
-
-def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor:
- tmp = x.clone()
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
- tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
- tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
- # final position ids
- y = torch.zeros_like(x, dtype=torch.long)
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
- (tmp[:, 1:] == VISION_TOKEN_TYPE) &
- (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
- y = y.cumsum(dim=-1)
- return y
-
-
-def _get_cogvlm_position_ids(context):
- """get cogvlm position_ids."""
- q_seqlens = context.q_seqlens
- history_lengths = context.kv_seqlens - q_seqlens
- vision_input_info = context.vision_inputs
- position_id_offsets = (vision_input_info.history_image_token_lengths -
- vision_input_info.history_image_nums * 3)
- lang_ids = None
- vis_ids = None
- if context.is_decoding:
- position_ids = history_lengths - position_id_offsets
- else:
- if vision_input_info.input_embeddings is not None and len(
- vision_input_info.input_embeddings) > 0:
- starts = history_lengths - vision_input_info.history_lengths
- ends = starts + q_seqlens
- token_type_ids = vision_input_info.input_embedding_indexing.to(
- torch.int)
- history_position_lengths = (vision_input_info.history_lengths -
- position_id_offsets)
- position_ids_all = (history_position_lengths[:, None] +
- build_position_ids(token_type_ids))
- position_ids = torch.cat([
- pids[s:e]
- for (pids, s, e) in zip(position_ids_all, starts, ends)
- ])
- vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids)
- vision_token_mask = torch.cat([
- masks[s:e]
- for (masks, s, e) in zip(vision_token_mask_all, starts, ends)
- ])
- mask_indexing = torch.arange(vision_token_mask.shape[-1],
- device=vision_token_mask.device)
- vis_ids = mask_indexing[vision_token_mask]
- lang_ids = mask_indexing[~vision_token_mask]
-
- else:
- position_ids = context.attention_mask.long().cumsum(-1) - 1
- position_ids += (history_lengths -
- position_id_offsets).unsqueeze(-1)
- device = position_ids.device
- position_ids_1d = [
- ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu())
- ]
- position_ids = torch.cat(position_ids_1d).to(device)
-
- return position_ids, lang_ids, vis_ids
-
-
-class CogVLMForCausalLM(nn.Module, CudaGraphMixin):
+class CogVLMForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin):
"""ModelForCausalLM."""
packed_modules_mapping = {
@@ -512,6 +728,8 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+ # preprocessor
+ self.input_processor = CogVLMInputProcessor(self.config, dtype)
# build model
self.model = CogVLMModel(config, dtype=dtype, device=device)
# build lm_head
@@ -527,6 +745,7 @@ def forward(
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ images: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
lang_ids: torch.LongTensor = None,
vision_ids: torch.LongTensor = None,
@@ -538,6 +757,7 @@ def forward(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
inputs_embeds=inputs_embeds,
lang_ids=lang_ids,
vision_ids=vision_ids,
@@ -561,8 +781,36 @@ def prepare_inputs_for_generation(
"""prepare input."""
# get input_ids, position_ids and attention metadatas
input_ids = context.input_ids
- position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context)
- position_ids = position_ids[None]
+
+ # position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context)
+ position_ids = context.position_ids
+ lang_ids = None
+ vis_ids = None
+
+ # vision inputs
+ images = None
+ if context.input_multimodals is not None:
+ images = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ images = [data for im_data in images for data in im_data]
+ if len(images) == 0:
+ images = None
+
+ if images is not None:
+ image_token_id = images[0].meta['image_token_id']
+ vis_mask = input_ids[0] == image_token_id
+ images = torch.stack([data.data for data in images])
+
+ # get lang_ids
+ vis_range = torch.arange(0,
+ input_ids.size(-1),
+ device=input_ids.device)
+ vis_ids = vis_range[vis_mask]
+ lang_ids = vis_range[~vis_mask]
+
attn_metadata = context.attn_metadata
# process vision embeddings
@@ -581,6 +829,7 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ images=images,
inputs_embeds=inputs_embeds,
lang_ids=lang_ids,
vision_ids=vis_ids,
@@ -597,8 +846,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
- if 'model.vision' in name:
- continue
if 'rotary_emb.inv_freq' in name:
continue
if ('rotary_emb.cos_cached' in name
@@ -607,6 +854,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if '.vision.' in name:
+ continue
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -620,6 +869,136 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
load_weight(param, q, shard_id='q')
load_weight(param, k, shard_id='k')
load_weight(param, v, shard_id='v')
+ elif '.query_key_value' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
else:
param = params_dict[name]
load_weight(param, loaded_weight)
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ model_metas = context.model_metas
+ input_multimodals = context.input_multimodals
+ if input_multimodals is None:
+ input_imgs = [[] for _ in model_metas]
+ else:
+ input_imgs = []
+ for mm in input_multimodals:
+ if mm is None:
+ input_imgs.append([])
+ else:
+ input_imgs.append(mm.get('image', []))
+
+ config = self.config
+ image_size: int = config.vision_config['image_size']
+ patch_size: int = config.vision_config['patch_size']
+ vision_token_num = ((image_size // patch_size // 2) *
+ (image_size // patch_size // 2) + 2)
+ num_pad = vision_token_num - 3
+
+ batched_num_img_tokens = []
+ new_model_metas = []
+ for meta, imgs in zip(model_metas, input_imgs):
+ if meta is None:
+ num_img_tokens = 0
+ else:
+ num_img_tokens = meta.get('num_img_tokens', 0)
+
+ batched_num_img_tokens.append(num_img_tokens)
+
+ num_img_tokens += num_pad * len(imgs)
+ new_model_metas.append(dict(num_img_tokens=num_img_tokens))
+
+ # prepare cogvlm position_ids
+ q_seqlens = context.q_seqlens
+ position_ids = context.position_ids
+
+ if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs):
+ num_img_tokens = torch.tensor(batched_num_img_tokens,
+ device=position_ids.device)
+ position_ids -= num_img_tokens[None]
+ else:
+ batched_position_ids = position_ids[0].split(q_seqlens)
+ for pos_ids, num_img_tok, imgs in zip(batched_position_ids,
+ batched_num_img_tokens,
+ input_imgs):
+ pos_ids -= num_img_tok
+ if len(imgs) == 0:
+ continue
+
+ seq_len = pos_ids.size(0)
+ start = pos_ids[0].cpu().item()
+ new_pos_ids = []
+
+ imgs = sorted(imgs, key=lambda img: img.start)
+ for img in imgs:
+ img_pad_pos = img.start + 1 - num_img_tok
+ num_pad = img.end - img.start - 2
+ new_pos_ids += list(range(start, img_pad_pos))
+ new_pos_ids += [img_pad_pos] * num_pad
+ start = img_pad_pos + 1
+ num_img_tok += num_pad
+
+ remain = seq_len - len(new_pos_ids)
+ new_pos_ids += list(range(start, start + remain))
+
+ new_pos_ids = pos_ids.new_tensor(new_pos_ids)
+ pos_ids[:] = new_pos_ids
+
+ position_ids = torch.cat(batched_position_ids)[None]
+ context.position_ids = position_ids
+
+ return new_model_metas
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class CogVLMInputProcessor(BaseModelInputProcessor):
+ """input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+ image_size: int = config.vision_config['image_size']
+ patch_size: int = config.vision_config['patch_size']
+ self.vision_token_num = ((image_size // patch_size // 2) *
+ (image_size // patch_size // 2) + 2)
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals=None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py
index 34debae229..66f68d90e5 100644
--- a/lmdeploy/pytorch/models/deepseek_v2.py
+++ b/lmdeploy/pytorch/models/deepseek_v2.py
@@ -90,6 +90,9 @@ def __init__(self,
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
+ num_key_value_heads = getattr(config, 'num_key_value_heads', 1)
if self.q_lora_rank is None:
self.q_proj = build_colwise_linear(
@@ -157,10 +160,9 @@ def __init__(self,
self.num_heads,
config.kv_lora_rank + self.qk_rope_head_dim,
scale=self.softmax_scale,
- num_kv_heads=1,
+ num_kv_heads=num_key_value_heads,
v_head_size=config.kv_lora_rank,
- replicate_kv=True,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
self.vc = DeepseekV2BMM(self.num_heads,
config.kv_lora_rank,
diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py
index 8f8659dc5e..2d2edb9f49 100644
--- a/lmdeploy/pytorch/models/falcon.py
+++ b/lmdeploy/pytorch/models/falcon.py
@@ -31,34 +31,31 @@ def __init__(self,
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
- self.num_kv_heads = self.num_attention_heads
+ self.num_kv_heads = getattr(config, 'num_kv_heads',
+ config.num_attention_heads)
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
self.head_size = (self.hidden_size // config.num_attention_heads)
- self.multi_query_attention = config.multi_query
- if self.multi_query_attention:
- self.num_kv_heads = 1
self.query_key_value = build_qkv_proj(
config.hidden_size,
num_q_heads=self.num_attention_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
bias=config.bias,
- replicate_kv=self.multi_query_attention,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# apply rotary
self.apply_rotary_pos_emb = ApplyRotaryEmb()
self.rotary = config.rotary
# attention
- self.attn_fwd = Attention(
- self.num_attention_heads,
- self.head_size,
- num_kv_heads=self.num_kv_heads,
- alibi=config.alibi,
- )
+ self.attn_fwd = Attention(self.num_attention_heads,
+ self.head_size,
+ num_kv_heads=self.num_kv_heads,
+ alibi=config.alibi)
# o_proj
self.dense = build_rowwise_linear(self.hidden_size,
diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py
index ca36f15651..86be85669e 100644
--- a/lmdeploy/pytorch/models/gemma.py
+++ b/lmdeploy/pytorch/models/gemma.py
@@ -31,7 +31,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = config.head_dim
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -42,7 +43,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
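
Editor's note: the `num_replicate_key_value_heads` hook added here (and in chatglm2, cogvlm, deepseek_v2, falcon and internlm2 in this patch) replaces the old `replicate_kv`/`multi_query` special cases: when tensor parallelism exceeds the number of kv heads, the heads are replicated instead of split. A sketch of the arithmetic, under the assumption that the engine picks the smallest replication factor that makes the kv heads divisible by the TP world size:

def kv_replication(num_kv_heads: int, tp: int) -> int:
    """Smallest r such that num_kv_heads * r is divisible by tp (assumed rule)."""
    r = 1
    while (num_kv_heads * r) % tp != 0:
        r += 1
    return r

assert kv_replication(1, 4) == 4   # MQA-style model on tp=4: full replicas
assert kv_replication(8, 4) == 1   # enough kv heads already: no replication
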
diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py
index 99c622e4ac..fdee716b4a 100644
--- a/lmdeploy/pytorch/models/internlm.py
+++ b/lmdeploy/pytorch/models/internlm.py
@@ -28,7 +28,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -39,7 +40,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py
index 6cbc2ccff3..db246331a1 100644
--- a/lmdeploy/pytorch/models/internlm2.py
+++ b/lmdeploy/pytorch/models/internlm2.py
@@ -28,7 +28,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = hidden_size // num_heads
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.wqkv = build_qkv_proj(
hidden_size,
@@ -39,6 +40,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
+ num_replicate_kv_heads=num_replicate_kv_heads,
)
# rotary embedding
@@ -395,6 +397,32 @@ def prepare_inputs_for_generation(
inputs_embeds=inputs_embeds,
)
+ def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+ adapter_id: int):
+ """load lora weights."""
+
+ from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+ num_heads = self.config.num_attention_heads
+ num_key_value_heads = self.config.num_key_value_heads
+ hidden_size = self.config.hidden_size
+ head_dim = hidden_size // num_heads
+ group_size = num_heads // num_key_value_heads
+
+ def _rearrange_wqkv(weights):
+ for name, loaded_weight in weights:
+ if 'wqkv.lora_B' in name:
+ loaded_weight = loaded_weight.unflatten(
+ 0, (-1, 2 + group_size, head_dim))
+ q = loaded_weight[:, :-2].flatten(0, 2)
+ k = loaded_weight[:, -2].flatten(0, 1)
+ v = loaded_weight[:, -1].flatten(0, 1)
+ loaded_weight = torch.cat([q, k, v], dim=0)
+ yield name, loaded_weight
+
+ weights_iter = _rearrange_wqkv(weights)
+ load_lora_weights(self, weights_iter, adapter_id)
+
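
Editor's note: the `_rearrange_wqkv` generator reorders the packed `wqkv` LoRA-B rows from the interleaved per-group layout (group_size q heads, then k, then v) into the flat `[q | k | v]` layout the engine's split projection expects. A toy shape check of that reordering (sizes invented):

import torch

num_heads, num_kv_heads, head_dim, rank = 4, 2, 8, 4
group_size = num_heads // num_kv_heads
lora_b = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, rank)

w = lora_b.unflatten(0, (-1, 2 + group_size, head_dim))
q = w[:, :-2].flatten(0, 2)   # (num_heads * head_dim, rank)
k = w[:, -2].flatten(0, 1)    # (num_kv_heads * head_dim, rank)
v = w[:, -1].flatten(0, 1)
assert torch.cat([q, k, v], dim=0).shape == lora_b.shape
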
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
# modify from vllm
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 70dd8f2159..1059569a09 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -1,17 +1,311 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
+import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj,
+ build_rowwise_linear)
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
-class InternVLChatModel(nn.Module, CudaGraphMixin):
+class InternVisionEmbeddings(nn.Module):
+ """intern vision embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), )
+
+ self.patch_embedding = nn.Conv2d(in_channels=3,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ dtype=dtype,
+ device=device)
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(
+ torch.empty(1,
+ self.num_positions,
+ self.embed_dim,
+ dtype=dtype,
+ device=device))
+
+ def _get_pos_embed(self, pos_embed, H, W):
+ target_dtype = pos_embed.dtype
+ pos_embed = pos_embed.float().reshape(
+ 1, self.image_size // self.patch_size,
+ self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
+ pos_embed = F.interpolate(pos_embed,
+ size=(H, W),
+ mode='bicubic',
+ align_corners=False).reshape(
+ 1, -1, H * W).permute(0, 2,
+ 1).to(target_dtype)
+ return pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+ pixel_values) # shape = [*, channel, width, height]
+ batch_size, _, height, width = patch_embeds.shape
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1,
+ -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ position_embedding = torch.cat([
+ self.position_embedding[:, :1, :],
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height,
+ width)
+ ],
+ dim=1)
+ embeddings = embeddings + position_embedding.to(target_dtype)
+ return embeddings
+
+
+NORM2FN = {
+ 'rms_norm': RMSNorm,
+ 'layer_norm': LayerNorm,
+}
+
+
+class InternAttention(nn.Module):
+ """intern vl attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+
+ self.qkv = build_qkv_proj(
+ self.embed_dim,
+ num_q_heads=self.num_heads,
+ num_kv_heads=self.num_heads,
+ head_size=self.head_dim,
+ bias=config.qkv_bias,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.qk_normalization = config.qk_normalization
+
+ if self.qk_normalization:
+ self.q_norm = RMSNorm(
+ self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device,
+ )
+ self.k_norm = RMSNorm(
+ self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.scale = self.head_dim**-0.5
+
+ # o_proj
+ self.proj = build_rowwise_linear(self.embed_dim,
+ self.embed_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ tp_align_size=self.head_dim)
+
+ def forward(self, hidden_states):
+ """forward."""
+
+ # qkv proj
+ qkv_states = self.qkv(hidden_states)
+ q, k, v = self.qkv.split_qkv(qkv_states)
+
+ if self.qk_normalization:
+ q_shape = q.shape
+ q = self.q_norm(q.flatten(-2, -1)).view(q_shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(q_shape)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class InternMLP(nn.Module):
+ """intern vl mlp."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.act = ACT2FN[config.hidden_act]
+
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class InternVisionEncoderLayer(nn.Module):
+ """intern vision encoder layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.norm_type = getattr(config, 'norm_type', 'rms_norm')
+
+ self.attn = InternAttention(config, dtype=dtype, device=device)
+ self.mlp = InternMLP(config, dtype=dtype, device=device)
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ self.ls1 = nn.Parameter(
+ torch.empty(self.embed_dim, dtype=dtype, device=device))
+ self.ls2 = nn.Parameter(
+ torch.empty(self.embed_dim, dtype=dtype, device=device))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ):
+ """forward."""
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+
+ hidden_states = hidden_states + self.mlp(
+ self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
+
+ return hidden_states
+
+
+class InternVisionEncoder(nn.Module):
+ """intern vision encoder."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ InternVisionEncoderLayer(config, dtype=dtype, device=device)
+ for idx in range(config.num_hidden_layers)
+ ])
+
+ def forward(
+ self,
+ inputs_embeds,
+ ):
+ """forward."""
+ hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(hidden_states)
+ return hidden_states
+
+
+class InternVisionModel(nn.Module):
+ """intern vision model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = InternVisionEmbeddings(config,
+ dtype=dtype,
+ device=device)
+ self.encoder = InternVisionEncoder(config, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ """forward."""
+ assert pixel_values.dim() == 4
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(inputs_embeds=hidden_states)
+ last_hidden_state = encoder_outputs
+
+ return last_hidden_state
+
+
+class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin):
def __init__(self,
config: PretrainedConfig,
@@ -21,31 +315,106 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+ self.select_layer = config.select_layer
+
llm_config = config.llm_config
+ self.llm_arch_name = llm_config.architectures[0]
+ self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
+
+ vision_config = config.vision_config
+ if self.is_mono:
+ from .internvl_patch import InternVisionPatchModel
+ self.vision_model = InternVisionPatchModel(
+ vision_config,
+ dtype=dtype,
+ device=device,
+ )
+ else:
+ self.vision_model = InternVisionModel(vision_config,
+ dtype=dtype,
+ device=device)
+
self.language_model = build_model_from_hf_config(llm_config,
dtype=dtype,
device=device)
- self.llm_arch_name = llm_config.architectures[0]
+ vit_hidden_size = config.vision_config.hidden_size
+ llm_hidden_size = config.llm_config.hidden_size
+ self.downsample_ratio = config.downsample_ratio
+ self.mlp1 = nn.Sequential(
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+ dtype=dtype,
+ device=device),
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+ llm_hidden_size,
+ dtype=dtype,
+ device=device), nn.GELU(),
+ nn.Linear(llm_hidden_size,
+ llm_hidden_size,
+ dtype=dtype,
+ device=device))
# for Mono-InternVL
- self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
if self.is_mono:
assert dtype != torch.float16, (
            'Currently Mono-InternVL does not support FP16 due to '
'numerical instability. Please use BF16 instead.')
+ self.input_processor = InternVLInputProcessor(self.config, dtype)
+
+ def pixel_shuffle(self, x, scale_factor=0.5):
+ n, w, h, c = x.size()
+ # N, W, H, C --> N, W, H * scale, C // scale
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+ x = x.permute(0, 2, 1, 3).contiguous()
+ # N, H * scale, W, C // scale -->
+ # N, H * scale, W * scale, C // (scale ** 2)
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
+ int(c / (scale_factor * scale_factor)))
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+ def extract_feature(self, pixel_values):
+ """extract vision feature."""
+ assert self.select_layer == -1
+ vit_embeds = self.vision_model(pixel_values)
+ if self.is_mono:
+ if int(vit_embeds.shape[1]**0.5)**2 != vit_embeds.shape[1]:
+ vit_embeds = vit_embeds[:, 1:, :]
+ else:
+ vit_embeds = vit_embeds[:, 1:, :]
+
+ h = w = int(vit_embeds.shape[1]**0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = self.pixel_shuffle(vit_embeds,
+ scale_factor=self.downsample_ratio)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
+ vit_embeds.shape[-1])
+ vit_embeds = self.mlp1(vit_embeds)
+ return vit_embeds
+
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
vision_embedding_indexing: torch.Tensor = None,
text_embedding_indexing: torch.Tensor = None,
**kwargs,
):
+ if inputs_embeds is None and pixel_values is not None:
+ # extract feature
+ vit_embeds = self.extract_feature(pixel_values)
+ lang_embeds = self.language_model.get_input_embeddings()(input_ids)
+ lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
+
+ inputs_embeds = lang_embeds
+
if self.is_mono:
return self.language_model.forward(
input_ids=input_ids,
@@ -80,11 +449,38 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
- # get inputs from context
vision_embeddings = context.input_embeddings
- vision_embedding_indexing = context.input_embedding_indexing
+ vision_embedding_indexing = None
+
+ # vision inputs
+ pixel_values = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ pixel_values = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ pixel_values = [
+ data for im_data in pixel_values for data in im_data
+ ]
+ if len(pixel_values) > 0:
+ image_token_id = pixel_values[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ pixel_values = torch.cat([data.data for data in pixel_values])
+ else:
+ pixel_values = None
+ image_mask = None
+ if self.is_mono and pixel_values is not None:
+ vision_embedding_indexing = torch.arange(input_ids.shape[1],
+ device=input_ids.device)
+ vision_embedding_indexing = vision_embedding_indexing[
+ image_mask[0]]
+
+ # get inputs from context
if vision_embeddings is not None and len(vision_embeddings) > 0:
+ vision_embedding_indexing = context.input_embedding_indexing
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds[:,
@@ -104,6 +500,8 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
vision_embedding_indexing=vision_embedding_indexing,
text_embedding_indexing=text_embedding_indexing,
@@ -114,18 +512,96 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
+ def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+ adapter_id: int):
+ """load lora weights."""
+
+ if hasattr(self.language_model, 'load_lora_weights'):
+ return self.language_model.load_lora_weights(weights, adapter_id)
+ else:
+ from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+ return load_lora_weights(weights, adapter_id)
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
- prefix_length = len('language_model.')
+ lang_prefix = 'language_model.'
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if name.startswith(lang_prefix):
+ continue
+
+ if 'qkv' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ lang_prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
- if not key.startswith('language_model.'):
+ if not key.startswith(lang_prefix):
continue
- new_key = key[prefix_length:]
+ new_key = key[lang_prefix_length:]
new_weights[new_key] = val
self.language_model.load_weights(new_weights.items())
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class InternVLInputProcessor(BaseModelInputProcessor):
+ """internvl input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ vision_config = config.vision_config
+ self.image_size = vision_config.image_size
+ self.patch_size = vision_config.patch_size
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+ self.vision_token_num = self.num_patches // 4
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
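Note on the multimodal path added above: it hinges on `image_mask` (positions whose token id equals the image token id recorded by the input processor) and `masked_scatter_`, which splices the projected ViT features into the language embeddings. A minimal, self-contained sketch of that mechanism, with toy shapes and a hypothetical image-token id rather than the real InternVL inputs:

```python
# Toy illustration of the image_mask / masked_scatter_ splice used above.
# IMAGE_TOKEN_ID, the shapes and the values are made up for the example.
import torch

IMAGE_TOKEN_ID = 92546                      # hypothetical image-token id
hidden_size = 4

# one sequence: 2 text tokens, 3 image placeholders, 1 text token
input_ids = torch.tensor(
    [[1, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 3]])
lang_embeds = torch.zeros(1, 6, hidden_size)                      # token embeddings
vit_embeds = torch.arange(3 * hidden_size,
                          dtype=torch.float32).view(3, hidden_size)

image_mask = input_ids == IMAGE_TOKEN_ID                          # (1, 6) bool
# broadcast over the hidden dim and scatter one ViT vector per image token
lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)

assert torch.equal(lang_embeds[0, 2:5], vit_embeds)               # image slots filled
```

In `prepare_inputs_for_generation` the same mask is derived from `pixel_values[0].meta['image_token_id']`, so the scatter lines up with the spans recorded by `InternVLInputProcessor`.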
diff --git a/lmdeploy/pytorch/models/internvl_patch.py b/lmdeploy/pytorch/models/internvl_patch.py
new file mode 100644
index 0000000000..d13ad2d39b
--- /dev/null
+++ b/lmdeploy/pytorch/models/internvl_patch.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+
+class InternVisionEmbeddings(nn.Module):
+ """mono vision."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), )
+
+ self.patch_embedding = nn.Conv2d(in_channels=3,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ dtype=dtype,
+ device=device)
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(
+ torch.empty(1,
+ self.num_positions,
+ self.embed_dim,
+ dtype=dtype,
+ device=device))
+
+ def _get_pos_embed(self, pos_embed, H, W):
+ target_dtype = pos_embed.dtype
+ pos_embed = pos_embed.float().reshape(
+ 1, self.image_size // self.patch_size,
+ self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
+ pos_embed = F.interpolate(pos_embed,
+ size=(H, W),
+ mode='bicubic',
+ align_corners=False)
+ pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
+ 1).to(target_dtype)
+ return pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+            pixel_values)  # shape = [batch, embed_dim, height, width]
+ batch_size, _, height, width = patch_embeds.shape
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1,
+ -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ position_embedding = torch.cat([
+ self.position_embedding[:, :1, :],
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height,
+ width)
+ ],
+ dim=1)
+ embeddings = embeddings + position_embedding.to(target_dtype)
+ return embeddings
+
+
+class InternVisionPatchModel(nn.Module):
+ """mono vision."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embeddings = InternVisionEmbeddings(config,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ if len(pixel_values.shape) != 4:
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
+
+ hidden_states = self.embeddings(pixel_values)[:, 1:]
+ return hidden_states
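The `pixel_shuffle` step in `InternVLChatModel.extract_feature` is pure tensor bookkeeping: with `downsample_ratio=0.5` it trades spatial resolution for channels, which is why `mlp1` is built for an input width of `vit_hidden_size * int(1 / downsample_ratio) ** 2`. A standalone sketch of that shape change (the function body is copied from the model above; the 16x16x1024 grid is a made-up example):

```python
import torch


def pixel_shuffle(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    """Same tensor gymnastics as InternVLChatModel.pixel_shuffle above."""
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h * scale_factor), int(w * scale_factor),
               int(c / (scale_factor * scale_factor)))
    return x.permute(0, 2, 1, 3).contiguous()


vit_embeds = torch.randn(2, 16, 16, 1024)   # (batch, h, w, vit_hidden_size)
out = pixel_shuffle(vit_embeds, scale_factor=0.5)
print(out.shape)                            # torch.Size([2, 8, 8, 4096])
```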
diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py
index f38c5ef02b..1a98c02f03 100644
--- a/lmdeploy/pytorch/models/llama.py
+++ b/lmdeploy/pytorch/models/llama.py
@@ -29,7 +29,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -40,6 +41,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
+ num_replicate_kv_heads=num_replicate_kv_heads,
)
# rotary embedding
@@ -450,22 +452,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
-
-
-class LlavaLlamaForCausalLM(LlamaForCausalLM):
- """llava llama for causallm."""
-
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- """load weights."""
-
- new_weights = dict()
- for key, val in weights:
- if key.startswith('model.vision_tower'):
- continue
- if key.startswith('model.mm_projector'):
- continue
- if key.startswith('model.image_newline'):
- continue
- new_weights[key] = val
-
- super().load_weights(new_weights.items())
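The new `num_replicate_key_value_heads` hook is about tensor parallelism with grouped-query attention: when the TP world size exceeds the number of real KV heads, the KV heads have to be duplicated so each rank still owns a complete head. A conceptual sketch of that replication with made-up sizes (not the `build_qkv_proj` internals):

```python
import torch

num_kv_heads, head_dim, seq_len = 2, 4, 3
num_replicate_kv_heads = 2        # e.g. tp=4 but only 2 real KV heads

k = torch.randn(seq_len, num_kv_heads, head_dim)
# duplicate every KV head so 4 shards can each take one full head
k_replicated = k.repeat_interleave(num_replicate_kv_heads, dim=1)
print(k_replicated.shape)         # torch.Size([3, 4, 4])
```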
diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py
index 56cb5ca675..751f7343ec 100644
--- a/lmdeploy/pytorch/models/llava.py
+++ b/lmdeploy/pytorch/models/llava.py
@@ -1,17 +1,443 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
+import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.llava.configuration_llava import LlavaConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj,
+ build_rowwise_linear)
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixin
-class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin):
+class LlavaMultiModalProjector(nn.Module):
+    """llava multimodal projector."""
+
+ def __init__(self,
+ config: LlavaConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device)
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(config.text_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class CLIPVisionEmbeddings(nn.Module):
+ """clip vision embedding."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(self.embed_dim, dtype=dtype, device=device))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(
+ self.num_positions,
+ self.embed_dim,
+ dtype=dtype,
+ device=device,
+ )
+ self.register_buffer('position_ids',
+ torch.arange(self.num_positions,
+ device=device).expand((1, -1)),
+ persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
+ width: int) -> torch.Tensor:
+ """This method allows to interpolate the pre-trained position
+ encodings, to be able to use the model on higher resolution images.
+
+ This method is also adapted to support torch.jit tracing.
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
+ num_positions = position_embedding.shape[1] - 1
+
+ # always interpolate when tracing
+ # to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing(
+ ) and num_patches == num_positions and height == width:
+ return self.position_embedding(self.position_ids)
+
+ from transformers.utils import torch_int
+
+ class_pos_embed = position_embedding[:, :1]
+ patch_pos_embed = position_embedding[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
+ sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode='bicubic',
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self,
+ pixel_values: torch.FloatTensor,
+ interpolate_pos_encoding=False) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ if not interpolate_pos_encoding and (height != self.image_size
+ or width != self.image_size):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f' ({self.image_size}*{self.image_size}).')
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(
+            dtype=target_dtype))  # shape = [batch, embed_dim, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(
+ embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embedding(
+ self.position_ids)
+ return embeddings
+
+
+class CLIPAttention(nn.Module):
+ """clip attention."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+
+ self.qkv_proj = build_qkv_proj(
+ self.embed_dim,
+ num_q_heads=self.num_heads,
+ num_kv_heads=self.num_heads,
+ head_size=self.head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.scale = self.head_dim**-0.5
+
+ # o_proj
+ self.out_proj = build_rowwise_linear(self.embed_dim,
+ self.embed_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """forward."""
+ # qkv proj
+ qkv_states = self.qkv_proj(hidden_states)
+ q, k, v = self.qkv_proj.split_qkv(qkv_states)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ if attention_mask is not None and causal_attention_mask is not None:
+ attn_mask = attention_mask + causal_attention_mask
+ elif causal_attention_mask is not None:
+ attn_mask = causal_attention_mask
+ else:
+ attn_mask = attention_mask
+
+ attn_output = F.scaled_dot_product_attention(q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ scale=self.scale)
+
+ # o proj
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(-2, -1)
+ attn_output = self.out_proj(attn_output)
+ return attn_output
+
+
+class CLIPMLP(nn.Module):
+ """clip mlp."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+        self.config = config
+        quantization_config = getattr(config, 'quantization_config', None)
+        from transformers.activations import ACT2FN
+        self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+ self.fc2 = build_rowwise_linear(
+ config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class CLIPEncoderLayer(nn.Module):
+ """clip encoder layer."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = CLIPAttention(config, dtype=dtype, device=device)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.mlp = CLIPMLP(config, dtype=dtype, device=device)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ ):
+ """forward."""
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class CLIPEncoder(nn.Module):
+ """clip encoder."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ CLIPEncoderLayer(config, dtype=dtype, device=device)
+ for _ in range(config.num_hidden_layers)
+ ])
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ vision_feature_layer: int = -1,
+ ):
+ """forward."""
+ hidden_states = inputs_embeds
+ num_vision_layers = len(self.layers) + vision_feature_layer + 1
+        for encoder_layer in self.layers[:num_vision_layers]:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+
+ hidden_states = layer_outputs
+
+ return hidden_states
+
+
+class CLIPVisionTransformer(nn.Module):
+ """clip vision transformer."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = CLIPVisionEmbeddings(config,
+ dtype=dtype,
+ device=device)
+ self.pre_layrnorm = nn.LayerNorm(embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+ self.encoder = CLIPEncoder(config, dtype=dtype, device=device)
+ self.post_layernorm = nn.LayerNorm(embed_dim,
+ eps=config.layer_norm_eps,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ interpolate_pos_encoding: bool = False,
+ vision_feature_layer: int = -1,
+ ) -> BaseModelOutputWithPooling:
+ """forward."""
+ hidden_states = self.embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ vision_feature_layer=vision_feature_layer)
+
+ last_hidden_state = encoder_outputs
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=None,
+ attentions=None,
+ )
+
+
+class CLIPVisionModel(nn.Module):
+ """clip vision model."""
+
+ def __init__(self,
+ config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.vision_model = CLIPVisionTransformer(config,
+ dtype=dtype,
+ device=device)
+
+ def forward(self,
+ pixel_values: torch.FloatTensor,
+ interpolate_pos_encoding: bool = False,
+ vision_feature_layer: int = -1,
+ **kwargs):
+ """forward."""
+ return self.vision_model(
+ pixel_values,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ vision_feature_layer=vision_feature_layer)
+
+
+def build_vision_model(vision_config,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ """build vision model."""
+ model_type = vision_config.model_type
+
+ if model_type == 'clip_vision_model':
+ return CLIPVisionModel(vision_config, dtype, device)
+ else:
+ raise NotImplementedError(f'<{model_type}> is not implemented.')
+
+
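`CLIPEncoder.forward` above runs only `len(self.layers) + vision_feature_layer + 1` layers instead of materialising every hidden state. A quick sanity check of that arithmetic, assuming a 24-layer CLIP-ViT-L tower and the common LLaVA default of `vision_feature_layer = -2`:

```python
# hidden_states in HF has num_hidden_layers + 1 entries (embeddings plus each
# layer output), so hidden_states[-2] is the output of layer 23 of 24 and only
# the first 23 encoder layers need to run.
num_hidden_layers = 24
vision_feature_layer = -2
num_vision_layers = num_hidden_layers + vision_feature_layer + 1
print(num_vision_layers)   # 23
```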
+class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin,
+ DeployModelMixin):
def __init__(self,
config: PretrainedConfig,
@@ -22,19 +448,67 @@ def __init__(self,
self.config = config
self.ctx_mgr = ctx_mgr
text_config = config.text_config
+
+ self.vision_tower = build_vision_model(config.vision_config,
+ dtype=dtype,
+ device=device)
+
self.language_model = build_model_from_hf_config(text_config,
dtype=dtype,
device=device)
+ self.multi_modal_projector = LlavaMultiModalProjector(config,
+ dtype=dtype,
+ device=device)
+
+ self.input_processor = LLavaInputProcessor(config, dtype)
+
+ def get_image_features(self,
+ pixel_values,
+ vision_feature_layer: int = -1,
+ vision_feature_select_strategy: str = 'default'):
+ """get image features."""
+ selected_image_feature = self.vision_tower(
+ pixel_values, vision_feature_layer=vision_feature_layer)[0]
+ if vision_feature_select_strategy == 'default':
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == 'full':
+ selected_image_feature = selected_image_feature
+ else:
+ raise ValueError(
+ f'Unexpected select feature strategy: {vision_feature_select_strategy}' # noqa: E501
+ )
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = image_features.flatten(0, 1)[None]
+
+ return image_features
+
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
**kwargs,
):
+ if inputs_embeds is None:
+ image_features = None
+ if pixel_values is not None:
+ vision_feature_layer = self.config.vision_feature_layer
+ select_strategy = self.config.vision_feature_select_strategy
+ image_features = self.get_image_features(
+ pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=select_strategy)
+ inputs_embeds = self.language_model.get_input_embeddings()(
+ input_ids)
+ if pixel_values is not None:
+ inputs_embeds.masked_scatter_(image_mask[..., None],
+ image_features)
+
return self.language_model.forward(input_ids=input_ids,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
@@ -59,6 +533,27 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
+
+ # vision inputs
+ pixel_values = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ pixel_values = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ pixel_values = [
+ data for im_data in pixel_values for data in im_data
+ ]
+ if len(pixel_values) > 0:
+ image_token_id = pixel_values[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ pixel_values = torch.cat([data.data for data in pixel_values])
+ else:
+ pixel_values = None
+ image_mask = None
+
# get inputs from context
vision_embeddings = context.input_embeddings
vision_embedding_indexing = context.input_embedding_indexing
@@ -75,18 +570,404 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
past_key_values=past_key_values,
attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_mask=image_mask,
inputs_embeds=inputs_embeds,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
- prefix_length = len('language_model.')
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ('.qkv_proj', '.q_proj', 'q'),
+ ('.qkv_proj', '.k_proj', 'k'),
+ ('.qkv_proj', '.v_proj', 'v'),
+ ]
+
+ # vis model
+ lang_prefix = 'language_model.'
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if name.startswith(lang_prefix):
+ continue
+
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, shard_id=shard_id)
+ break
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ # language model
+ prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
- if not key.startswith('language_model.'):
+ if not key.startswith(lang_prefix):
continue
new_key = key[prefix_length:]
new_weights[new_key] = val
self.language_model.load_weights(new_weights.items())
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class LLavaInputProcessor(BaseModelInputProcessor):
+ """llava input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    """Calculate the patch grid shape after any-resolution preprocessing."""
+
+ from transformers.image_processing_utils import select_best_resolution
+
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError('grid_pinpoints should be a list of tuples or lists')
+
+ if not isinstance(image_size, (list, tuple)):
+ image_size = image_size.tolist()
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def unpad_image(tensor, original_size):
+ """Unpads a PyTorch tensor of a padded and resized image."""
+ if not isinstance(original_size, (list, tuple)):
+ original_size = original_size.tolist()
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(round(original_height * scale_factor, 7))
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(round(original_width * scale_factor, 7))
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
+
+ return unpadded_tensor
+
+
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+ """Calculate the number of patches after the preprocessing for images of
+ any resolution."""
+ from transformers.image_processing_utils import select_best_resolution
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError('grid_pinpoints should be a list of tuples or lists')
+
+ if not isinstance(image_size, (list, tuple)):
+ image_size = image_size.tolist()
+
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
+ height, width = best_resolution
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ # add the base patch
+ num_patches += 1
+ return num_patches
+
+
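For intuition on `image_size_to_num_patches`, here is a worked example assuming the usual llava-v1.6 settings (336-pixel tiles and the standard `image_grid_pinpoints`); the grid values and the 640x480 input are illustrative, not read from any config:

```python
from transformers.image_processing_utils import select_best_resolution

grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
image_size = (640, 480)   # (height, width) of the original image
patch_size = 336          # one vision tile, i.e. vision_config.image_size

best_h, best_w = select_best_resolution(list(image_size), grid_pinpoints)
num_patches = (best_h // patch_size) * (best_w // patch_size) + 1  # + base tile
print((best_h, best_w), num_patches)   # (672, 672) 5
```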
+class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration):
+    """llava-next for conditional generation."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__(config=config,
+ ctx_mgr=ctx_mgr,
+ dtype=dtype,
+ device=device)
+ self.image_newline = nn.Parameter(
+ torch.empty(config.text_config.hidden_size,
+ dtype=dtype,
+ device=device))
+ self.input_processor = LLavaNextInputProcessor(config, dtype)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: int,
+ vision_feature_select_strategy: str,
+ ):
+ # ! infer image_num_patches from image_sizes
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ ) for imsize in image_sizes
+ ]
+ if pixel_values.dim() == 5:
+ # stacked if input is
+ # (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [
+ pix_val[:num_patch]
+ for pix_val, num_patch in zip(pixel_values, image_num_patches)
+ ]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of
+ # (num_patches, num_channels, height, width)
+ raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
+ 'expect to be of 4 or 5 dimensions')
+
+ selected_image_feature = self.vision_tower(
+ pixel_values, vision_feature_layer=vision_feature_layer)[0]
+ if vision_feature_select_strategy == 'default':
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == 'full':
+ selected_image_feature = selected_image_feature
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = torch.split(image_features, image_num_patches, dim=0)
+ return image_features
+
+ def pack_image_features(self,
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy,
+ image_newline=None):
+
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = (self.config.vision_config.image_size //
+ self.config.vision_config.patch_size)
+
+ if vision_feature_select_strategy == 'default':
+ expected_num_patches = height * width
+ elif vision_feature_select_strategy == 'full':
+ expected_num_patches = height * width + 1
+ if expected_num_patches != base_image_feature.shape[0]:
+ raise ValueError('The number of patches is '
+ 'not consistent with the image size.')
+
+ (num_patch_height,
+ num_patch_width) = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(num_patch_height,
+ num_patch_width, height,
+ width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1,
+ 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature,
+ image_sizes[image_idx])
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None].expand(
+ *image_feature.shape[:-1], 1).to(
+ image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature),
+ dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (image_feature, image_newline[None].to(image_feature)),
+ dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ image_features = torch.cat(new_image_features, dim=0)
+ return image_features
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_sizes: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
+ inputs_embeds: torch.Tensor = None,
+ **kwargs,
+ ):
+ if inputs_embeds is None:
+ image_features = None
+ if pixel_values is not None:
+ vision_feature_layer = self.config.vision_feature_layer
+ select_strategy = self.config.vision_feature_select_strategy
+ image_sizes = image_sizes.tolist()
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=select_strategy)
+ image_features = self.pack_image_features(
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy=select_strategy,
+ image_newline=self.image_newline,
+ )
+ image_features = image_features[None]
+ inputs_embeds = self.language_model.get_input_embeddings()(
+ input_ids)
+ if pixel_values is not None:
+ inputs_embeds.masked_scatter_(image_mask[..., None],
+ image_features)
+
+ return self.language_model.forward(input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ attn_metadata=attn_metadata)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: torch.Tensor = None,
+ context: StepContext = None,
+ ):
+ """prepare input."""
+ input_ids = context.input_ids
+ position_ids = context.position_ids
+ attn_metadata = context.attn_metadata
+
+ # vision inputs
+ pixel_values = None
+ image_sizes = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ img_mms = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ img_mms = [data for im_data in img_mms for data in im_data]
+ if len(img_mms) > 0:
+ image_token_id = img_mms[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ pixel_values = torch.cat(
+ [data.data.flatten(0, 1) for data in img_mms])
+ image_sizes = torch.cat(
+ [data.meta['image_sizes'] for data in img_mms])
+ else:
+ pixel_values = None
+ image_sizes = None
+
+ # get inputs from context
+ vision_embeddings = context.input_embeddings
+ vision_embedding_indexing = context.input_embedding_indexing
+
+ if vision_embeddings is not None and len(vision_embeddings) > 0:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds[:,
+ vision_embedding_indexing, :] = vision_embeddings.to(
+ inputs_embeds)
+
+ return dict(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ image_mask=image_mask,
+ inputs_embeds=inputs_embeds,
+ )
+
+
+class LLavaNextInputProcessor(BaseModelInputProcessor):
+ """llava input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ image_sizes = input_mm['image_sizes']
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(
+ image_sizes=image_sizes,
+ image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
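The `unpad_image` helper above is easiest to see with numbers: a wide original image is letterboxed into a square feature grid, and the padded rows are cropped away again. A self-contained example (the function body mirrors the one above; the 24x24 grid and 336x672 size are made up):

```python
import torch


def unpad_image(tensor, original_size):
    """Crop away the letterbox padding rows/columns, as in llava.py above."""
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]
    if original_width / original_height > current_width / current_height:
        scale_factor = current_width / original_width
        new_height = int(round(original_height * scale_factor, 7))
        padding = (current_height - new_height) // 2
        return tensor[:, padding:current_height - padding, :]
    scale_factor = current_height / original_height
    new_width = int(round(original_width * scale_factor, 7))
    padding = (current_width - new_width) // 2
    return tensor[:, :, padding:current_width - padding]


feat = torch.randn(4096, 24, 24)              # (channels, grid_h, grid_w)
print(unpad_image(feat, (336, 672)).shape)    # torch.Size([4096, 12, 24])
```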
diff --git a/lmdeploy/pytorch/models/minicpmv26.py b/lmdeploy/pytorch/models/minicpmv26.py
index 725e97d9d7..e551dda841 100644
--- a/lmdeploy/pytorch/models/minicpmv26.py
+++ b/lmdeploy/pytorch/models/minicpmv26.py
@@ -29,7 +29,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -40,7 +41,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py
index 04af4c8526..ad27963093 100644
--- a/lmdeploy/pytorch/models/mistral.py
+++ b/lmdeploy/pytorch/models/mistral.py
@@ -420,22 +420,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
-
-
-class LlavaMistralForCausalLM(MistralForCausalLM):
- """llava forcausallm."""
-
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- """load weights."""
-
- new_weights = dict()
- for key, val in weights:
- if key.startswith('model.vision_tower'):
- continue
- if key.startswith('model.mm_projector'):
- continue
- if key.startswith('model.image_newline'):
- continue
- new_weights[key] = val
-
- super().load_weights(new_weights.items())
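The `LlavaMistralForCausalLM` shim removed here (like `LlavaLlamaForCausalLM` in llama.py) only existed to drop vision weights before handing a LLaVA checkpoint to the plain LLM; the rewritten llava.py now routes weights by prefix instead. A toy sketch of that routing, with made-up parameter names:

```python
weights = [
    ('vision_tower.vision_model.embeddings.class_embedding', 'w0'),
    ('multi_modal_projector.linear_1.weight', 'w1'),
    ('language_model.model.embed_tokens.weight', 'w2'),
]
lang_prefix = 'language_model.'

# vision / projector weights stay with the multimodal wrapper ...
vis_weights = {k: v for k, v in weights if not k.startswith(lang_prefix)}
# ... while language weights are re-keyed and passed to the LLM loader
lang_weights = {k[len(lang_prefix):]: v for k, v in weights
                if k.startswith(lang_prefix)}

print(sorted(vis_weights))    # vision/projector weights
print(sorted(lang_weights))   # ['model.embed_tokens.weight']
```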
diff --git a/lmdeploy/pytorch/models/mllama.py b/lmdeploy/pytorch/models/mllama.py
index 2596fe5299..15b3e9732b 100644
--- a/lmdeploy/pytorch/models/mllama.py
+++ b/lmdeploy/pytorch/models/mllama.py
@@ -3,23 +3,61 @@
import torch
from torch import nn
+from torch.nn import functional as F
from transformers.models.llama import LlamaConfig
-from transformers.models.mllama.modeling_mllama import MllamaTextConfig
+from transformers.models.mllama.modeling_mllama import (MllamaTextConfig,
+ MllamaVisionConfig)
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
- SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, RMSNorm,
+ RopeType, SiluAndMul, build_rotary_embedding)
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.nn.rotary_embedding import Llama3Parameters
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
-from .utils.cudagraph import CudaGraphMixin
+from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2
+from .utils.model import DeployModelMixin
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = '<|image|>'
+def _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask: torch.Tensor,
+ num_patches: int,
+ target_length: int,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ # Expand aspect ratio mask to target_length
+ batch_size, max_num_tiles = aspect_ratio_mask.shape
+ attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
+ 1).to(dtype)
+ attention_mask = attention_mask.repeat(1, 1, target_length, 1)
+
+ # Mask padding patches
+ pad_patches = target_length - num_patches
+ attention_mask[:, :, -pad_patches:] = 0
+
+ # Invert the mask (0 -> 1, 1 -> 0)
+ attention_mask = 1 - attention_mask
+
+ # Reshape to 2D and create 4D attention mask
+ # (batch_size, 1, max_num_tiles * target_length,
+ # max_num_tiles * target_length)
+ attention_mask = attention_mask.reshape(batch_size,
+ max_num_tiles * target_length, 1)
+ attention_mask = attention_mask * attention_mask.transpose(
+ -1, -2) * torch.finfo(dtype).min
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
+
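`_prepare_aspect_ratio_attention_mask` above is dense; a toy walk-through with made-up sizes (one image, `max_num_tiles=2` with only the first tile real, `num_patches=3` padded to `target_length=4`) shows what it produces:

```python
import torch

dtype = torch.float32
aspect_ratio_mask = torch.tensor([[1, 0]])     # (batch, max_num_tiles)
num_patches, target_length = 3, 4

m = aspect_ratio_mask.view(1, 2, 1, 1).to(dtype).repeat(1, 1, target_length, 1)
m[:, :, -(target_length - num_patches):] = 0   # zero out the padded patch slot
m = 1 - m                                      # now 0 = keep, 1 = padding
m = m.reshape(1, 2 * target_length, 1)
m = m * m.transpose(-1, -2) * torch.finfo(dtype).min  # padding x padding pairs
m = m.unsqueeze(1)                             # (batch, 1, 8, 8) additive mask

print(m.shape)                # torch.Size([1, 1, 8, 8])
print((m < 0).sum().item())   # 25 = 5 padded positions squared
```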
class LlamaAttention(nn.Module):
"""Rewrite module of LlamaAttention."""
@@ -157,6 +195,7 @@ def __init__(self,
self.head_dim,
num_kv_heads=self.num_key_value_heads,
v_head_size=self.head_dim,
+ causal=False,
)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -579,7 +618,542 @@ def get_logits(self, hidden_states: torch.Tensor):
return self.lm_head(hidden_states)
-class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin):
+class MllamaPrecomputedPositionEmbedding(nn.Module):
+ """vis position embedding."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+ self.config = config
+ self.num_patches = (config.image_size // config.patch_size)**2 + 1
+ self.hidden_size = config.hidden_size
+
+ self.gate = nn.Parameter(torch.empty(1, dtype=dtype, device=device))
+
+ # position embedding
+ self.embedding = nn.Parameter(
+ torch.empty(self.num_patches,
+ self.hidden_size,
+ dtype=dtype,
+ device=device))
+
+ # tile position embedding
+ self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
+ self.max_num_tiles *
+ self.num_patches * self.hidden_size,
+ dtype=dtype,
+ device=device)
+
+ self._weight_inited = False
+
+ def _init_weight(self):
+ """init weight."""
+ if self._weight_inited:
+ return
+
+ gate_tanh = self.gate.tanh()
+ gated_position_embedding = (1 - gate_tanh) * self.embedding
+ self.gate_tanh = gate_tanh
+ self.gated_position_embedding = gated_position_embedding.view(
+ 1, 1, self.num_patches, self.hidden_size)
+
+ self._weight_inited = True
+
+ def forward(self, hidden_state: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ self._init_weight()
+
+ # position embeddings
+ hidden_state = hidden_state + self.gated_position_embedding
+
+ # precomputed tile position embeddings
+ tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+ batch_size = hidden_state.shape[0]
+ tile_position_embedding = tile_position_embedding.reshape(
+ batch_size, self.max_num_tiles, self.num_patches, self.hidden_size)
+ gated_tile_position_embedding = (self.gate_tanh *
+ tile_position_embedding)
+ hidden_state = hidden_state + gated_tile_position_embedding
+
+ return hidden_state
+
+
+class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
+    """precomputed aspect-ratio embedding."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ is_gated: bool = True,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.hidden_size = config.hidden_size
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+ self.is_gated = is_gated
+
+ self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
+ self.max_num_tiles * self.hidden_size,
+ dtype=dtype,
+ device=device)
+ if is_gated:
+ self.gate = nn.Parameter(torch.empty(1, dtype=dtype,
+ device=device))
+
+ self._weight_inited = False
+
+ def _init_weight(self):
+ """init weight."""
+ if self._weight_inited:
+ return
+
+ gate_tanh = self.gate.tanh()
+ self.gate_tanh = gate_tanh
+
+ self._weight_inited = True
+
+ def forward(self, hidden_state: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+ self._init_weight()
+ embeddings = self.embedding(aspect_ratio_ids)
+ embeddings = embeddings.reshape(-1, self.max_num_tiles, 1,
+ self.hidden_size)
+
+ if self.is_gated:
+ embeddings = embeddings * self.gate_tanh
+
+ hidden_state = hidden_state + embeddings
+ return hidden_state
+
+
+class MllamaVisionAttention(nn.Module):
+ """mllama vision attention."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.attention_heads
+ self.head_dim = config.hidden_size // config.attention_heads
+
+ # packed qkv
+ self.qkv_proj = build_qkv_proj(
+ self.embed_dim,
+ num_q_heads=self.num_heads,
+ num_kv_heads=self.num_heads,
+ head_size=self.head_dim,
+ bias=False,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # o_proj
+ self.o_proj = build_rowwise_linear(self.num_heads * self.head_dim,
+ self.embed_dim,
+ bias=False,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_state.size(0)
+ qkv_states = self.qkv_proj(hidden_state)
+ qkv_states = qkv_states.flatten(0, -2)
+ query, key, value = self.qkv_proj.split_qkv(qkv_states)
+
+ query = query.unflatten(0, (batch_size, -1))
+ key = key.unflatten(0, (batch_size, -1))
+ value = value.unflatten(0, (batch_size, -1))
+ q_seq_len = query.shape[1]
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(query,
+ key,
+ value,
+ attn_mask=attention_mask)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+ output = self.o_proj(attn_output)
+
+ return output
+
+
+class MllamaVisionMLP(nn.Module):
+ """mllama vision mlp."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = build_colwise_linear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+ self.fc2 = build_rowwise_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MllamaVisionEncoderLayer(nn.Module):
+ """vision encoder layer."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ is_gated: bool,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.is_gated = is_gated
+ self.self_attn = MllamaVisionAttention(config,
+ dtype=dtype,
+ device=device)
+ self.mlp = MllamaVisionMLP(config, dtype=dtype, device=device)
+
+ self.input_layernorm = LayerNorm(self.hidden_size,
+ eps=config.norm_eps,
+ dtype=dtype,
+ device=device)
+ self.post_attention_layernorm = LayerNorm(self.hidden_size,
+ eps=config.norm_eps,
+ dtype=dtype,
+ device=device)
+
+ if is_gated:
+ self.gate_attn = nn.Parameter(
+ torch.empty(1, dtype=dtype, device=device))
+ self.gate_ffn = nn.Parameter(
+ torch.empty(1, dtype=dtype, device=device))
+
+ self._weight_inited = not is_gated
+
+ def _init_weight(self):
+ """init weight."""
+ if self._weight_inited:
+ return
+
+ self.gate_attn_tanh = self.gate_attn.tanh()
+ self.gate_ffn_tanh = self.gate_ffn.tanh()
+
+ self._weight_inited = True
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """forward."""
+ self._init_weight()
+
+ # Self Attention
+ residual = hidden_state
+ hidden_state = self.input_layernorm(hidden_state)
+ hidden_state = self.self_attn(hidden_state,
+ attention_mask=attention_mask)
+ if self.is_gated:
+ hidden_state = self.gate_attn_tanh * hidden_state
+ hidden_state = residual + hidden_state
+
+ # Feed forward
+ residual = hidden_state
+ hidden_state = self.post_attention_layernorm(hidden_state)
+ hidden_state = self.mlp(hidden_state)
+ if self.is_gated:
+ hidden_state = self.gate_ffn_tanh * hidden_state
+ hidden_state = residual + hidden_state
+
+ outputs = hidden_state
+
+ return outputs
+
+
+class MllamaVisionEncoder(nn.Module):
+ """vision encoder."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ num_layers=32,
+ is_gated=False,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ MllamaVisionEncoderLayer(config,
+ is_gated,
+ dtype=dtype,
+ device=device) for _ in range(num_layers)
+ ])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """forward."""
+ encoder_states = ()
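+ # record the input of every layer plus the final output so intermediate features can be picked out later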
+ for encoder_layer in self.layers:
+ encoder_states = encoder_states + (hidden_states, )
+ hidden_states = encoder_layer(
+ hidden_state=hidden_states,
+ attention_mask=attention_mask,
+ )
+ encoder_states = encoder_states + (hidden_states, )
+
+ return hidden_states, encoder_states
+
+
+class MllamaVisionModel(nn.Module):
+ """vision model."""
+
+ def __init__(self,
+ config: MllamaVisionConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+
+ self.config = config
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.hidden_size = config.hidden_size
+ self.intermediate_layers_indices = config.intermediate_layers_indices
+ self.dtype = dtype
+
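+ # patches per tile plus one class token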
+ self.num_patches = (self.image_size // self.patch_size)**2 + 1
+ self.scale = config.hidden_size**-0.5
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.hidden_size,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding='valid',
+ bias=False,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.class_embedding = nn.Parameter(
+ torch.empty(self.hidden_size, dtype=dtype, device=device))
+ self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
+ config,
+ dtype=dtype,
+ device=device,
+ )
+
+ self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501
+ config,
+ is_gated=True,
+ dtype=dtype,
+ device=device,
+ )
+ self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501
+ config,
+ is_gated=True,
+ dtype=dtype,
+ device=device,
+ )
+
+ # layer norms
+ self.layernorm_pre = nn.LayerNorm(
+ self.hidden_size,
+ dtype=dtype,
+ device=device,
+ )
+ self.layernorm_post = nn.LayerNorm(
+ self.hidden_size,
+ dtype=dtype,
+ device=device,
+ )
+
+ # encoders
+ self.transformer = MllamaVisionEncoder(
+ config,
+ config.num_hidden_layers,
+ is_gated=False,
+ dtype=dtype,
+ device=device,
+ )
+ self.global_transformer = MllamaVisionEncoder(
+ config,
+ config.num_global_layers,
+ is_gated=True,
+ dtype=dtype,
+ device=device,
+ )
+
+ def apply_class_embedding(self,
+ hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1,
+ hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor,
+ aspect_ratio_mask: torch.Tensor,
+ ):
+ """forward."""
+ (batch_size, num_concurrent_media, num_tiles, num_channels, height,
+ width) = pixel_values.shape
+
+ pixel_values = pixel_values.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_channels,
+ height, width)
+ aspect_ratio_ids = aspect_ratio_ids.reshape(
+ batch_size * num_concurrent_media, -1)
+
+ # Patch embedding
+ patch_embeds = self.patch_embedding(pixel_values.to(self.dtype))
+ hidden_state = patch_embeds.flatten(2).transpose(1, 2)
+
+ # Tile embeddings
+ _, num_patches, dim = hidden_state.shape
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles, -1, dim)
+ hidden_state = self.pre_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids)
+
+ # Add cls token
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_patches, dim)
+ hidden_state = self.apply_class_embedding(hidden_state)
+ num_patches += 1
+
+ # Position embeddings
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles, num_patches, dim)
+ hidden_state = self.gated_positional_embedding(hidden_state,
+ aspect_ratio_ids)
+
+ hidden_state = self.layernorm_pre(hidden_state)
+
+ # Compute the number of tokens to pad
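+ # (padding the patch axis to a multiple of 8, presumably to keep the attention shapes kernel friendly)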
+ num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
+ # Compute padding tuple for pad function
+ padding = (
+ 0, 0, 0, num_padding_patches
+ )  # (left, right) for dim -1, then (left, right) for dim -2
+ # Pad the tensor
+ hidden_state = F.pad(hidden_state, padding, mode='constant', value=0)
+ slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+ # Prepare attention mask
+ attention_mask = aspect_ratio_mask.reshape(
+ batch_size * num_concurrent_media, -1)
+ attention_mask = _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask=attention_mask,
+ num_patches=self.num_patches,
+ target_length=hidden_state.shape[2],
+ dtype=self.dtype,
+ )
+
+ # Apply encoder
+ hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
+ dim)
+ output = self.transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ )
+ hidden_state = output[0]
+
+ hidden_state = self.layernorm_post(hidden_state)
+
+ # Apply global encoder
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim)
+ hidden_state = self.post_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids)
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles * (num_patches + num_padding_patches), dim)
+ global_output = self.global_transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ )
+ hidden_state = global_output[0]
+
+ # Remove padding from the hidden state
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim)
+ hidden_state = hidden_state[:, :, :slice_index]
+ hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
+ num_tiles, num_patches, dim)
+
+ # Collect intermediate layer outputs from encoder output
+ all_intermediate_hidden_states = output[1]
+ all_intermediate_hidden_states = [
+ all_intermediate_hidden_states[i]
+ for i in self.intermediate_layers_indices
+ ]
+ intermediate_hidden_states = torch.stack(
+ all_intermediate_hidden_states, dim=-1)
+
+ # Remove padding from intermediate hidden states
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size * num_concurrent_media, num_tiles,
+ num_patches + num_padding_patches, -1)
+ intermediate_hidden_states = intermediate_hidden_states[:, :, :
+ slice_index]
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size, num_concurrent_media, num_tiles, num_patches, -1)
+
+ # Concatenate final hidden state and intermediate hidden states
+ hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
+ dim=-1)
+
+ return hidden_state
+
+
+class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin,
+ DeployModelMixin):
"""rewrote model of MllamaForConditionalGeneration."""
packed_modules_mapping = {
@@ -602,16 +1176,32 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+
+ self.vision_model = MllamaVisionModel(
+ config.vision_config,
+ dtype=dtype,
+ device=device,
+ )
# build MllamaForCausalLM
self.language_model = MllamaForCausalLM(config.text_config,
dtype=dtype,
device=device)
+
+ self.multi_modal_projector = build_rowwise_linear(
+ config.vision_config.vision_output_dim,
+ config.text_config.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ )
self.dtype = dtype
- def flat_encoder_result(self, cross_attention_states: torch.Tensor,
- attn_metadata: Any, input_ids: torch.LongTensor):
+ # preprocessor
+ self.input_processor = MLlamaInputProcessor(self.config, dtype)
+
+ def flat_encoder_result(self, attn_metadata: Any,
+ input_ids: torch.LongTensor):
# since every state share the same shape
- cross_attention_states = torch.cat(cross_attention_states, 0)
full_text_row_masked_out_mask = torch.ones(
(attn_metadata.q_seqlens.sum(), 1), dtype=torch.bool)
start_pos = 0
@@ -621,39 +1211,51 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
full_text_row_masked_out_mask[start_pos:img_id] = False
start_pos += q_seq_len
full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
- cross_attention_states.device)
+ input_ids.device)
- return cross_attention_states, full_text_row_masked_out_mask
+ return full_text_row_masked_out_mask
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: List[List[torch.Tensor]],
- cross_attention_states: Optional[torch.Tensor] = None,
+ pixel_values: torch.Tensor = None,
+ aspect_ratio_ids: torch.Tensor = None,
+ aspect_ratio_mask: torch.Tensor = None,
attn_metadata: Any = None,
inputs_embeds: torch.Tensor = None,
cross_attn_metadata: Any = None,
**kwargs,
):
"""model forward, return logits."""
+
if cross_attn_metadata is None:
full_text_row_masked_out_mask = None
# FIXME basically, we want to inference
# text requests and image requests separately
- elif cross_attention_states is None and (
- cross_attn_metadata.kv_seqlens is None
- or int(cross_attn_metadata.kv_seqlens.sum()) == 0):
+ elif pixel_values is None and (cross_attn_metadata.kv_seqlens is None):
full_text_row_masked_out_mask = None
elif cross_attn_metadata.is_decoding:
- cross_attention_states = None
- full_text_row_masked_out_mask = torch.ones(
- (attn_metadata.q_seqlens.sum(), 1),
- dtype=torch.bool,
- device=input_ids.device)
+ full_text_row_masked_out_mask = input_ids.new_ones(
+ input_ids.size(-1), 1)
else:
- cross_attention_states, full_text_row_masked_out_mask = \
- self.flat_encoder_result(cross_attention_states, cross_attn_metadata, input_ids) # noqa
+ full_text_row_masked_out_mask = self.flat_encoder_result(
+ cross_attn_metadata, input_ids) # noqa
+
+ cross_attention_states = None
+ if pixel_values is not None:
+ cross_attention_states = self.vision_model(
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ )
+ cross_attention_states = self.multi_modal_projector(
+ cross_attention_states)
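+ # flatten the (tiles, patches) axes into a single cross-attention token axis per request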
+ _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape)
+ cross_attention_states = cross_attention_states.view(
+ bsz, -1, image_token_dim)
+
hidden_states = self.language_model(
input_ids=input_ids,
position_ids=position_ids,
@@ -670,15 +1272,6 @@ def get_logits(self, hidden_states: torch.Tensor):
"""compute logits of the model output."""
return self.language_model.get_logits(hidden_states)
- def support_cuda_graph(
- self,
- input_ids: torch.Tensor,
- **kwargs,
- ):
- """support cudagraph."""
-
- return False
-
def get_input_embeddings(self):
"""get input embeddings."""
return self.language_model.model.get_input_embeddings()
@@ -694,14 +1287,35 @@ def prepare_inputs_for_generation(
input_ids = context.input_ids
position_ids = context.position_ids
attn_metadata = context.attn_metadata
- cross_attention_states = context.cross_attention_states
- if cross_attention_states is not None:
- cross_attention_states = [
- t.to(input_ids.device) for t in cross_attention_states
- if t is not None
- ]
cross_attn_metadata = context.cross_attn_metadata
+ # cross_attn_metadata is None when the request carries no image
+ if cross_attn_metadata is not None and int(
+ cross_attn_metadata.kv_seqlens.sum()) == 0:
+ cross_attn_metadata.kv_seqlens = None
+
+ device = input_ids.device
+
+ # process image input
+ pixel_values = None
+ aspect_ratio_ids = None
+ aspect_ratio_mask = None
+ if context.input_multimodals is not None:
+ pixel_values = []
+ aspect_ratio_ids = []
+ aspect_ratio_mask = []
+ batched_image_data = [
+ input_mm['image'] for input_mm in context.input_multimodals
+ ]
+ for image_data in batched_image_data:
+ for data in image_data:
+ pixel_values.append(data.data)
+ aspect_ratio_ids.append(data.meta['aspect_ratio_ids'])
+ aspect_ratio_mask.append(data.meta['aspect_ratio_mask'])
+ pixel_values = torch.cat(pixel_values, dim=0).to(device)
+ aspect_ratio_ids = torch.cat(aspect_ratio_ids, dim=0).to(device)
+ aspect_ratio_mask = torch.cat(aspect_ratio_mask, dim=0).to(device)
+
# process vision embeddings
vision_embeddings = context.input_embeddings
vision_embedding_indexing = context.input_embedding_indexing
@@ -719,7 +1333,9 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
- cross_attention_states=cross_attention_states,
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
cross_attn_metadata=cross_attn_metadata,
)
@@ -742,8 +1358,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if ('rotary_emb.cos_cached' in name
or 'rotary_emb.sin_cached' in name):
continue
- if 'vision_model' in name or 'multi_modal_projector' in name:
- continue
if self.config.text_config.tie_word_embeddings and 'lm_head.weight' in name: # noqa
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
@@ -756,3 +1370,161 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
+
+ def support_cuda_graph(
+ self,
+ input_ids: torch.Tensor,
+ attn_metadata: Any,
+ cross_attn_metadata: Any,
+ **kwargs,
+ ):
+ """support cudagraph."""
+
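+ # capture cuda graphs only for decoding steps that already hold cross-attention kv; prefill and text-only requests stay in eager mode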
+ if not attn_metadata.is_decoding:
+ return False
+
+ if cross_attn_metadata is None:
+ return False
+
+ if cross_attn_metadata.kv_seqlens is None:
+ return False
+
+ return True
+
+ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """make cudagraph buffers from forward inputs."""
+ input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta,
+ **kwargs)
+
+ device = graph_meta.device
+ max_batches = graph_meta.max_batchs
+ input_buffers['cross_kv_seqlens'] = torch.zeros(max_batches,
+ dtype=torch.int64,
+ device=device)
+
+ return input_buffers
+
+ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """fill cudagraph buffers from forward inputs."""
+ input_buffers = graph_meta.input_buffers
+
+ new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta,
+ **kwargs)
+
+ attn_metadata = new_inputs['attn_metadata']
+ cross_attn_metadata = new_inputs['cross_attn_metadata']
+ block_offsets = attn_metadata.block_offsets
+ batch_size, _ = block_offsets.size()
+
+ kv_seqlens = cross_attn_metadata.kv_seqlens
+ if kv_seqlens.data_ptr() != input_buffers['cross_kv_seqlens'].data_ptr(
+ ):
+ input_buffers['cross_kv_seqlens'].zero_()
+ input_buffers['cross_kv_seqlens'][:batch_size] = kv_seqlens
+
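+ # reuse the padded decoding buffers for the cross-attention metadata so the captured graph always sees the same tensor addresses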
+ new_batch_size = next_power_of_2(batch_size)
+ cross_attn_metadata.block_offsets = input_buffers[
+ 'block_offsets'][:new_batch_size]
+ cross_attn_metadata.q_start_loc = input_buffers[
+ 'q_start_loc'][:new_batch_size]
+ cross_attn_metadata.q_seqlens = input_buffers[
+ 'q_seqlens'][:new_batch_size]
+ cross_attn_metadata.kv_seqlens = input_buffers[
+ 'cross_kv_seqlens'][:new_batch_size]
+
+ new_inputs['cross_attn_metadata'] = cross_attn_metadata
+ return new_inputs
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ model_metas = context.model_metas
+ if model_metas is None:
+ batch_size = context.q_seqlens.size(0)
+ model_metas = [dict(cross_kv_len=0) for _ in range(batch_size)]
+
+ if context.is_decoding:
+ return model_metas
+
+ vision_inputs = context.vision_inputs
+ if vision_inputs is None:
+ return model_metas
+
+ input_mms = vision_inputs.input_multimodals
+ if input_mms is None:
+ return model_metas
+
+ config = self.config.vision_config
+ image_size = config.image_size
+ patch_size = config.patch_size
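+ # cross-attention kv per image: (patches + 1 cls) tokens per tile, times an assumed default of 4 tiles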
+ wh = image_size // patch_size
+ img_kv_len = wh * wh + 1
+ img_kv_len = img_kv_len * 4
+
+ new_model_metas = []
+ for idx, input_mm in enumerate(input_mms):
+ if input_mm is None:
+ new_model_metas.append(model_metas[idx])
+ continue
+ images = input_mm['image']
+ num_img = len(images)
+
+ cross_kv_len = 0
+ if model_metas[idx] is not None:
+ cross_kv_len = model_metas[idx].get('cross_kv_len',
+ cross_kv_len)
+ cross_kv_len += img_kv_len * num_img
+ new_model_metas.append(dict(cross_kv_len=cross_kv_len))
+
+ return new_model_metas
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class MLlamaInputProcessor(BaseModelInputProcessor):
+ """mllama input processor."""
+
+ def __init__(self, config: LlamaConfig, dtype: torch.dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ vision_config = self.config.vision_config
+ image_size = vision_config.image_size
+ patch_size = vision_config.patch_size
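+ # encoder length per image: (patches + 1 cls) tokens per tile, times an assumed default of 4 tiles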
+ wh = image_size // patch_size
+ encoder_len = wh * wh + 1
+ encoder_len = encoder_len * 4
+ self.encoder_len = encoder_len
+
+ def preprocess_input(self, input_ids, input_multimodals, **kwargs):
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values']
+ aspect_ratio_ids = input_mm['aspect_ratio_ids']
+ aspect_ratio_mask = input_mm['aspect_ratio_mask']
+ offset = input_mm['offset']
+
+ if pixel_values.dtype != self.dtype:
+ pixel_values = pixel_values.to(self.dtype)
+
+ mm_data = MultiModalTensor(
+ data=pixel_values,
+ start=offset,
+ end=offset + 1,
+ encoder_len=self.encoder_len,
+ meta=dict(aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index 1059bfee4e..e7b460026a 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -85,14 +85,10 @@
# llava
MODULE_MAP.update(
{
- 'LlavaLlamaForCausalLM':
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM',
- 'LlavaMistralForCausalLM':
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM',
'LlavaForConditionalGeneration':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501
'LlavaNextForConditionalGeneration': # noqa: E501
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration'
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaNextForConditionalGeneration' # noqa: E501
})
# qwen
@@ -158,7 +154,7 @@
# phi3 vision
MODULE_MAP.update({
'Phi3VForCausalLM':
- f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3VForCausalLM',
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_v.Phi3VForCausalLM',
})
# phi-3.5-moe
diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py
index 9da1b9f4ea..9604b19af5 100644
--- a/lmdeploy/pytorch/models/patch.py
+++ b/lmdeploy/pytorch/models/patch.py
@@ -8,6 +8,7 @@
import torch
from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import load_state_dict
from lmdeploy.utils import get_logger
@@ -250,6 +251,10 @@ def add_adapters(model: torch.nn.Module,
ranks, scalings = get_ranks_and_scalings(target_name,
adapter_cfgs,
device=device)
+ # split in case target_name contains '.', e.g. 'attention.wo',
+ # which cannot be used as a module name and does not match the
+ # keys in model.packed_modules_mapping
+ target_name = target_name.split('.')[-1]
found_mods, pack_idx = find_all_target(model, target_name)
sum_rank = ranks.sum().item()
@@ -295,7 +300,9 @@ def add_adapters(model: torch.nn.Module,
for name, path in adapters.items():
adapter_id = adapter_id_map[name]
checkpoint_path = f'{path}/adapter_model.bin'
- state_dict = torch.load(checkpoint_path, map_location=device)
+ if not osp.exists(checkpoint_path):
+ checkpoint_path = f'{path}/adapter_model.safetensors'
+ state_dict = load_state_dict(checkpoint_path, map_location=device)
if hasattr(model, 'load_lora_weights'):
model.load_lora_weights(state_dict.items(), adapter_id=adapter_id)
diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py
index f9477fdab8..288fdf3b19 100644
--- a/lmdeploy/pytorch/models/phi3.py
+++ b/lmdeploy/pytorch/models/phi3.py
@@ -435,7 +435,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
param = params_dict[name]
load_weight(param, loaded_weight)
-
-
-class Phi3VForCausalLM(Phi3ForCausalLM):
- ...
diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py
new file mode 100644
index 0000000000..c4bf72c767
--- /dev/null
+++ b/lmdeploy/pytorch/models/phi3_v.py
@@ -0,0 +1,476 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
+
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
+from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn.linear import build_rowwise_linear
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+
+from .phi3 import Phi3ForCausalLM, Phi3Model
+from .utils.model import DeployModelMixin
+
+CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0,
+ dropout=0.0,
+ hidden_act='quick_gelu',
+ hidden_size=1024,
+ image_size=336,
+ initializer_factor=1.0,
+ initializer_range=0.02,
+ intermediate_size=4096,
+ layer_norm_eps=1e-05,
+ num_attention_heads=16,
+ num_channels=3,
+ num_hidden_layers=24,
+ patch_size=14,
+ projection_dim=768)
+
+
+class Phi3ImageEmbedding(nn.Module):
+ """image embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ wte=None,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ **kwargs):
+ super().__init__()
+ self.config = config
+ hidden_size = config.n_embd if hasattr(
+ config, 'n_embd') else config.hidden_size
+
+ self.wte = wte
+
+ if (isinstance(config.img_processor, dict) and
+ config.img_processor.get('name', None) == 'clip_vision_model'):
+ assert 'model_name' in config.img_processor, (
+ 'model_name must be provided for CLIPVisionModel')
+ assert 'image_dim_out' in config.img_processor, (
+ 'image_dim_out must be provided for CLIPVisionModel')
+ assert 'num_img_tokens' in config.img_processor, (
+ 'num_img_tokens must be provided for CLIPVisionModel')
+ assert config.img_processor[
+ 'model_name'] == 'openai/clip-vit-large-patch14-336'
+ clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
+ self.img_processor = CLIPVisionModel(clip_config).to(device).to(
+ dtype)
+ image_dim_out = config.img_processor['image_dim_out']
+ self.num_img_tokens = config.img_processor['num_img_tokens']
+ else:
+ raise NotImplementedError(
+ f'img_processor = {config.img_processor}, not implemented')
+
+ self.image_dim_out = image_dim_out
+ self.img_sizes = None
+
+ self.use_hd_transform = kwargs.get('use_hd_transform', False)
+ self.with_learnable_separator = kwargs.get('with_learnable_separator',
+ False)
+ self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub')
+ # with_hd_transform and with_learnable_separator should have same value
+ assert (self.use_hd_transform == self.with_learnable_separator), (
+ 'use_hd_transform and with_learnable_separator '
+ 'should have same value')
+ if self.with_learnable_separator:
+ assert self.use_hd_transform, (
+ 'learnable separator is only for hd transform')
+ # 1024 * 4, merge spatial to channel dimension
+ self.glb_GN = nn.Parameter(
+ torch.empty([1, 1, self.image_dim_out * 4],
+ dtype=dtype,
+ device=device))
+ self.sub_GN = nn.Parameter(
+ torch.empty([1, 1, 1, self.image_dim_out * 4],
+ dtype=dtype,
+ device=device))
+
+ projection_cls = kwargs.get('projection_cls', 'linear')
+ if projection_cls == 'linear':
+ self.img_projection = nn.Linear(image_dim_out,
+ hidden_size,
+ dtype=dtype,
+ device=device)
+ elif projection_cls == 'mlp' and self.use_hd_transform:
+ dim_projection = hidden_size
+ depth = 2
+ layers = [
+ nn.Linear(image_dim_out * 4,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ]
+ for _ in range(1, depth):
+ layers.extend([
+ nn.GELU(),
+ nn.Linear(dim_projection,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ])
+ self.img_projection = nn.Sequential(*layers)
+ elif projection_cls == 'mlp':
+ dim_projection = hidden_size
+ depth = 2
+ layers = [
+ nn.Linear(image_dim_out,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ]
+ for _ in range(1, depth):
+ layers.extend([
+ nn.GELU(),
+ nn.Linear(dim_projection,
+ dim_projection,
+ dtype=dtype,
+ device=device)
+ ])
+ self.img_projection = nn.Sequential(*layers)
+ else:
+ raise NotImplementedError(
+ f'projection_cls = {projection_cls}, not implemented')
+
+ self.vocab_size = config.vocab_size
+ self.img_features = None
+
+ if isinstance(config.img_processor, dict):
+ self.layer_idx = config.img_processor.get('layer_idx', -2)
+ self.type_feature = config.img_processor.get(
+ 'type_feature', 'patch')
+ else:
+ self.layer_idx = -2
+ self.type_feature = 'patch'
+
+ def get_img_features(self,
+ img_embeds: torch.FloatTensor) -> torch.FloatTensor:
+ LAYER_IDX = self.layer_idx
+ TYPE_FEATURE = self.type_feature
+
+ img_processor_output = self.img_processor(img_embeds,
+ output_hidden_states=True)
+ img_feature = img_processor_output.hidden_states[LAYER_IDX]
+
+ if TYPE_FEATURE == 'patch':
+ patch_feature = img_feature[:, 1:]
+ return patch_feature
+
+ if TYPE_FEATURE == 'cls_patch':
+ return img_feature
+
+ raise NotImplementedError
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ pixel_values: torch.FloatTensor,
+ image_sizes=None,
+ image_mask: torch.Tensor = None,
+ ) -> torch.FloatTensor:
+ """forward."""
+
+ target_device = pixel_values.device
+ target_dtype = pixel_values.dtype
+
+ img_embeds = pixel_values
+ img_sizes = image_sizes
+ img_sizes = img_sizes.cpu()
+
+ if self.use_hd_transform and img_sizes is not None and len(img_sizes):
+ assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501
+ # img_embeds: (num_images, max_num_crops, 3, H, W)
+ # img_sizes: (num_images, 2).view(1, -1)
+
+ bs = img_embeds.shape[0]
+ # Nx(HW)xC
+ img_features = self.get_img_features(img_embeds.flatten(0, 1))
+ base_feat_height = base_feat_width = int(
+ img_features.shape[1]**0.5)
+
+ assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501
+
+ # bs x max_num_crops x (24x24) x C
+ img_features = img_features.view(
+ bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
+ C = self.image_dim_out
+ H = base_feat_height
+
+ output_imgs = []
+ output_len = []
+ # img_sizes is a tensor during training and a list at inference time
+ if isinstance(img_sizes, torch.Tensor):
+ img_sizes = img_sizes.view(-1, 2)
+ for _bs in range(bs):
+ h, w = img_sizes[_bs]
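+ # number of 336x336 crops along each side of the padded image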
+ h = h // 336
+ w = w // 336
+ B_ = h * w
+
+ # 1 x (24x24) x 1024
+ global_img_feature = img_features[_bs, :1]
+
+ # 1 x 12 x 12 x 4096
+ glb_img = global_img_feature.reshape(
+ 1, H // 2, 2, H // 2, 2,
+ C).permute(0, 1, 3, 2, 4,
+ 5).reshape(1, H // 2, H // 2, 4 * C)
+ temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
+
+ # 1 x 156 x 4096
+ glb_img = torch.cat([glb_img, temp_glb_GN],
+ dim=2).reshape(1, -1, 4 * C)
+
+ # (max_num_crops-1) x (12x12) x C
+ sub_img = img_features[_bs, 1:]
+ # 16x574x1024
+ # get rid of padding sub_img
+ sub_img = sub_img[:B_]
+
+ # (num_crops, 12, 2, 12, 2, 1024)
+ # ->(num_crops, 12, 12, 2, 2, 1024)
+ # -> (num_crops, 12*12, 4*1024)
+ sub_img = (sub_img.reshape(B_, H // 2, 2, H // 2, 2,
+ C).permute(0, 1, 3, 2, 4, 5))
+ sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(
+ 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C)
+ temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
+ sub_img = torch.cat([sub_img, temp_sub_GN],
+ dim=2).reshape(1, -1, 4 * C)
+ # (1, num_img_tokens, 1024*4)
+
+ # glb + sub
+ if self.hd_transform_order == 'glb_sub':
+ output_imgs.append(
+ torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
+ elif self.hd_transform_order == 'sub_glb':
+ output_imgs.append(
+ torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
+ else:
+ raise NotImplementedError(
+ f'hd_transform_order = {self.hd_transform_order}'
+ ) # noqa E501
+
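+ # expected visual token count: 144 tokens per 336x336 crop (global plus h*w sub-crops), plus the separator tokens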
+ temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
+ assert temp_len == output_imgs[-1].shape[
+ 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501
+ output_len.append(temp_len)
+
+ img_set_tensor = []
+ for _output_img in output_imgs:
+ img_feature_proj = self.img_projection(
+ _output_img.to(target_device).to(target_dtype))
+ img_feature_proj = img_feature_proj.flatten(0, 1)
+ img_set_tensor.append(img_feature_proj)
+ img_set_tensor = torch.cat(img_set_tensor)[None]
+ elif img_embeds.ndim == 4:
+ tt = (self.get_img_features(img_embeds).to(target_device).to(
+ target_dtype).reshape(-1, self.image_dim_out))
+ img_set_tensor = self.img_projection(
+ tt) # adapted visual features.
+ elif img_embeds.ndim == 3:
+ tt = (img_embeds.to(target_device).to(target_dtype).view(
+ -1, self.image_dim_out))
+ img_set_tensor = self.img_projection(
+ tt) # adapted visual features.
+ else:
+ raise NotImplementedError
+
+ hidden_states = self.wte(input_ids)
+
+ hidden_states.masked_scatter_(image_mask[..., None], img_set_tensor)
+
+ return hidden_states
+
+
+class Phi3VModel(Phi3Model):
+ """phi3v model."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__(config=config, dtype=dtype, device=device)
+
+ self.vision_embed_tokens = None
+ if isinstance(config.embd_layer, dict):
+ # vision embedding layer
+ embedding_config = {
+ 'embedding_cls': config.embd_layer['embedding_cls'],
+ **config.embd_layer
+ }
+ self.vision_embed_tokens = Phi3ImageEmbedding(
+ config,
+ wte=self.embed_tokens,
+ dtype=dtype,
+ device=device,
+ **embedding_config)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ attn_metadata: Any = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ image_mask: torch.Tensor = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ """Rewrite of LlamaModel.forward."""
+
+ if inputs_embeds is None and pixel_values is not None:
+ inputs_embeds = self.vision_embed_tokens(
+ input_ids,
+ pixel_values,
+ image_sizes,
+ image_mask,
+ )
+
+ return super().forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ )
+
+
+class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin):
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__(config, ctx_mgr, dtype=dtype, device=device)
+ self.config = config
+ self.ctx_mgr = ctx_mgr
+ # build model
+ self.model = Phi3VModel(config, dtype=dtype, device=device)
+ # build lm_head
+ self.lm_head = build_rowwise_linear(config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ self.input_processor = Phi3VInputProcessor(config, dtype)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ pixel_values: torch.Tensor = None,
+ image_sizes: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
+ inputs_embeds: torch.Tensor = None,
+ **kwargs,
+ ):
+ """forward."""
+ hidden_states = self.model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ image_mask=image_mask,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: torch.Tensor = None,
+ context: StepContext = None,
+ ):
+ """prepare input."""
+ output = super().prepare_inputs_for_generation(
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ context=context)
+
+ # vision inputs
+ pixel_values = None
+ if context.input_multimodals is not None:
+ input_mms = [
+ input_mm.get('image', [])
+ for input_mm in context.input_multimodals
+ ]
+ # flatten batch
+ input_mms = [data for im_data in input_mms for data in im_data]
+ if len(input_mms) > 0:
+ pixel_values = torch.cat([data.data for data in input_mms])
+ image_sizes = torch.cat(
+ [data.meta['image_sizes'] for data in input_mms])
+ image_token_id = input_mms[0].meta['image_token_id']
+ image_mask = output['input_ids'] == image_token_id
+ output['pixel_values'] = pixel_values
+ output['image_sizes'] = image_sizes
+ output['image_mask'] = image_mask
+
+ return output
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """load weights."""
+ super().load_weights(weights)
+
+ vis_prefix = 'vision_embed_tokens.'
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if vis_prefix not in name:
+ continue
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+class Phi3VInputProcessor(BaseModelInputProcessor):
+ """Phi3V input processor."""
+
+ def __init__(self, config: PretrainedConfig, dtype) -> None:
+ self.config = config
+ self.dtype = dtype
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values'].to(self.dtype)
+ image_sizes = input_mm['image_sizes']
+ offset = input_mm['offset']
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(data=pixel_values,
+ start=offset,
+ end=offset + num_pad,
+ meta=dict(
+ image_sizes=image_sizes,
+ image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py
index 82be75e167..38773c21e1 100644
--- a/lmdeploy/pytorch/models/qwen2.py
+++ b/lmdeploy/pytorch/models/qwen2.py
@@ -29,7 +29,8 @@ def __init__(self,
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
-
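+ # kv heads may be replicated so that tensor parallelism can exceed the kv-head count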
+ num_replicate_kv_heads = getattr(config,
+ 'num_replicate_key_value_heads', 1)
# packed qkv
self.qkv_proj = build_qkv_proj(
hidden_size,
@@ -40,7 +41,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
- )
+ num_replicate_kv_heads=num_replicate_kv_heads)
# rotary embedding
self.apply_rotary_pos_emb = ApplyRotaryEmb()
diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py
index b10baaa4d5..4e2b1017b5 100644
--- a/lmdeploy/pytorch/models/qwen2_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_vl.py
@@ -1,18 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
+from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
+ PreprocessInputResult)
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
- SiluAndMul, build_rotary_embedding)
-from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
+from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention,
+ LayerNorm, RMSNorm, RopeType, SiluAndMul,
+ build_rotary_embedding)
+from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
+ build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2
+from .utils.model import DeployModelMixin
def _apply_mrope_selection(hidden_states: torch.Tensor,
@@ -337,7 +343,337 @@ def get_input_embeddings(self):
return self.embed_tokens
-class Qwen2VLForConditionalGeneration(nn.Module, CudaGraphMixin):
+class PatchEmbed(nn.Module):
+ """Patch Embed."""
+
+ def __init__(self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ dtype: torch.dtype = None,
+ device: torch.device = None) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels,
+ embed_dim,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
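+ # incoming patches arrive flattened; restore (N, C, T, H, W) so the Conv3d projection can consume them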
+ hidden_states = hidden_states.view(-1, self.in_channels,
+ self.temporal_patch_size,
+ self.patch_size, self.patch_size)
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
+ -1, self.embed_dim)
+ return hidden_states
+
+
+class VisionRotaryEmbedding(nn.Module):
+ """vision rotary embedding."""
+
+ def __init__(self,
+ dim: int,
+ theta: float = 10000.0,
+ device: torch.device = None) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta**(
+ torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class VisionAttention(nn.Module):
+ """Vision attention."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ dim = config.embed_dim
+ num_heads = config.num_heads
+ head_dim = dim // num_heads
+ self.head_dim = head_dim
+
+ # packed qkv
+ self.qkv = build_qkv_proj(
+ dim,
+ num_q_heads=num_heads,
+ num_kv_heads=num_heads,
+ head_size=head_dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # rotary embedding
+ self.apply_rotary_pos_emb = ApplyRotaryEmb()
+
+ # attention
+ self.attention = FlashAttention(
+ num_heads,
+ head_dim,
+ causal=False,
+ )
+
+ # o_proj
+ self.proj = build_rowwise_linear(dim,
+ dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ # qkv proj
+ qkv_states = self.qkv(hidden_states)
+ # (-1, heads, head_dim)
+ qkv_states = qkv_states.flatten(0, -2)
+ q, k, v = self.qkv.split_qkv(qkv_states)
+
+ cos, sin = rotary_pos_emb
+ q, k = self.apply_rotary_pos_emb(q, k, cos, sin)
+
+ attn_output = self.attention(
+ q,
+ k,
+ v,
+ q_start_loc=cu_seqlens[:-1],
+ q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1],
+ )
+
+ attn_output = attn_output.reshape(seq_length, -1)
+
+ # o proj
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class VisionMlp(nn.Module):
+ """Vision mlp."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+ dim = config.embed_dim
+ hidden_dim = int(config.embed_dim * config.mlp_ratio)
+ quantization_config = getattr(config, 'quantization_config', None)
+ # fc1 (up projection)
+ self.fc1 = build_colwise_linear(
+ dim,
+ hidden_dim,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # activation
+ if config.hidden_act in [
+ 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python'
+ ]:
+ self.act = nn.GELU()
+ else:
+ self.act = ACT2FN[config.hidden_act]
+
+ # down
+ self.fc2 = build_rowwise_linear(hidden_dim,
+ dim,
+ bias=True,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x):
+ """forward."""
+ return self.fc2(self.act(self.fc1(x)))
+
+
+class Qwen2VLVisionBlock(nn.Module):
+ """Vision block."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.norm1 = LayerNorm(config.embed_dim,
+ eps=1e-6,
+ dtype=dtype,
+ device=device)
+ self.norm2 = LayerNorm(config.embed_dim,
+ eps=1e-6,
+ dtype=dtype,
+ device=device)
+
+ self.attn = VisionAttention(config, dtype=dtype, device=device)
+
+ self.mlp = VisionMlp(config, dtype=dtype, device=device)
+
+ def forward(self,
+ hidden_states,
+ cu_seqlens,
+ rotary_pos_emb,
+ residual: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ else:
+ hidden_states, residual = self.norm1(hidden_states, residual)
+
+ hidden_states = self.attn(hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb)
+
+ hidden_states, residual = self.norm2(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+class PatchMerger(nn.Module):
+ """PatchMerger."""
+
+ def __init__(self,
+ dim: int,
+ context_dim: int,
+ spatial_merge_size: int = 2,
+ dtype: torch.dtype = None,
+ device: torch.device = None) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ self.ln_q = nn.LayerNorm(context_dim,
+ eps=1e-6,
+ dtype=dtype,
+ device=device)
+ self.mlp = nn.Sequential(
+ nn.Linear(self.hidden_size,
+ self.hidden_size,
+ dtype=dtype,
+ device=device),
+ nn.GELU(),
+ nn.Linear(self.hidden_size, dim, dtype=dtype, device=device),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+ return x
+
+
+class Qwen2VisionTransformerPretrainedModel(nn.Module):
+ """Vision transformer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.spatial_merge_size = config.spatial_merge_size
+
+ self.patch_embed = PatchEmbed(
+ patch_size=config.patch_size,
+ temporal_patch_size=config.temporal_patch_size,
+ in_channels=config.in_channels,
+ embed_dim=config.embed_dim,
+ dtype=dtype,
+ device=device,
+ )
+
+ head_dim = config.embed_dim // config.num_heads
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2,
+ device=device)
+
+ self.blocks = nn.ModuleList([
+ Qwen2VLVisionBlock(config, layer_idx, dtype=dtype, device=device)
+ for layer_idx in range(config.depth)
+ ])
+ self.merger = PatchMerger(dim=config.hidden_size,
+ context_dim=config.embed_dim,
+ spatial_merge_size=config.spatial_merge_size,
+ dtype=dtype,
+ device=device)
+
+ def rot_pos_emb(self, grid_thw):
+ """rotary position embedding."""
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(
+ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ hidden_states = self.patch_embed(hidden_states)
+ cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+
+ residual = None
+ for blk in self.blocks:
+ hidden_states, residual = blk(hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ residual=residual)
+
+ hidden_states = hidden_states + residual
+
+ return self.merger(hidden_states)
+
+
+class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin,
+ CudaGraphMixin):
"""ModelForCausalLM."""
packed_modules_mapping = {
@@ -360,6 +696,16 @@ def __init__(self,
super().__init__()
self.config = config
self.ctx_mgr = ctx_mgr
+
+ # preprocessor
+ self.input_processor = Qwen2VLInputProcessor(self.config)
+
+ # build vision model
+ self.visual = Qwen2VisionTransformerPretrainedModel(
+ config.vision_config,
+ dtype=dtype,
+ device=device,
+ )
# build model
self.model = Qwen2Model(config, dtype=dtype, device=device)
# build lm_head
@@ -377,9 +723,26 @@ def forward(
attn_metadata: Any = None,
inputs_embeds: torch.Tensor = None,
mrope_position_ids: torch.Tensor = None,
+ pixel_values: torch.Tensor = None,
+ vis_cu_seqlens: torch.Tensor = None,
+ vis_pos_emb: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
**kwargs,
):
"""model forward, return logits."""
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ if pixel_values is not None:
+ dtype = inputs_embeds.dtype
+ pixel_values = pixel_values.to(dtype)
+ vis_pos_emb = (vis_pos_emb[0].to(dtype),
+ vis_pos_emb[1].to(dtype))
+ image_embeds = self.visual(pixel_values,
+ cu_seqlens=vis_cu_seqlens,
+ rotary_pos_emb=vis_pos_emb)
+ inputs_embeds = inputs_embeds.masked_scatter(
+ image_mask[..., None], image_embeds)
+
hidden_states = self.model(
input_ids=input_ids,
position_ids=position_ids,
@@ -416,6 +779,36 @@ def prepare_inputs_for_generation(
position_ids = context.position_ids
attn_metadata = context.attn_metadata
+ pixel_values = None
+ vis_cu_seqlens = None
+ vis_pos_emb = None
+ image_mask = None
+ if context.input_multimodals is not None:
+ image_data = [
+ input_mm['image'] for input_mm in context.input_multimodals
+ ]
+
+ if len(image_data) > 0:
+ # flatten batch
+ image_data = [
+ data for im_data in image_data for data in im_data
+ ]
+ pixel_values = torch.cat([data.data for data in image_data])
+ image_token_id = image_data[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ grid_thw = torch.cat(
+ [data.meta['grid_thw'] for data in image_data]).cpu()
+ vis_pos_emb = self.visual.rot_pos_emb(grid_thw)
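+ # cu_seqlens marks frame boundaries: h*w patches per frame, repeated t times, then accumulated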
+ vis_cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2],
+ grid_thw[:, 0]).to(pixel_values.device)
+ vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0,
+ dtype=torch.int32)
+ vis_pos_emb = vis_pos_emb.repeat(1, 2)
+ vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())
+
+ mrope_position_ids = getattr(context, 'mrope_position_ids', None)
+
# process vision embeddings
vision_embeddings = context.input_embeddings
vision_embedding_indexing = context.input_embedding_indexing
@@ -433,7 +826,11 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
- mrope_position_ids=context.mrope_position_ids,
+ mrope_position_ids=mrope_position_ids,
+ pixel_values=pixel_values,
+ vis_cu_seqlens=vis_cu_seqlens,
+ vis_pos_emb=vis_pos_emb,
+ image_mask=image_mask,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -450,8 +847,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
- if 'visual' in name:
- continue
if 'rotary_emb.inv_freq' in name:
continue
if ('rotary_emb.cos_cached' in name
@@ -467,8 +862,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
load_weight(param, loaded_weight, shard_id=shard_id)
break
else:
- param = params_dict[name]
- load_weight(param, loaded_weight)
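+ # the vision checkpoint stores qkv packed in one tensor; split it and load the q/k/v shards separately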
+ if '.qkv.' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
"""make cudagraph buffers from forward inputs."""
@@ -510,3 +912,130 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
'mrope_position_ids']
return new_inputs
+
+ def _update_model_meta_decoding(self, context: StepContext):
+ """update model meta for decoding."""
+ model_metas = context.model_metas
+ position_ids = context.position_ids
+
+ mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
+ mrope_deltas = position_ids.new_tensor(mrope_deltas)
+ mrope_position_ids = position_ids + mrope_deltas[None]
+ mrope_position_ids = mrope_position_ids.expand(3, -1)
+
+ context.mrope_position_ids = mrope_position_ids
+ return model_metas
+
+ def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):
+ """get mrope ids."""
+ t, h, w = grid_thw
+ h //= 2
+ w //= 2
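+ # decompose a flat index over the (t, h, w) grid into per-axis position ids via integer strides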
+ stride = torch.tensor([h * w, w, 1], device=device)[:, None]
+ size = torch.tensor([t, h, w], device=device)[:, None]
+ pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
+ pos_ids = pos_ids // stride % size
+ return pos_ids
+
+ def _update_model_meta_prefilling(self, context: StepContext):
+ """update model meta for prefilling."""
+ model_metas = context.model_metas
+ input_multimodals = context.input_multimodals
+ if input_multimodals is None:
+ input_multimodals = [None] * len(model_metas)
+ position_ids = context.position_ids
+ batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())
+ mrope_position_ids = []
+ new_model_metas = []
+ for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas,
+ input_multimodals):
+ images = []
+ if input_mm is not None:
+ images = input_mm['image']
+ if model_meta is None or 'mrope_delta' not in model_meta:
+ mrope_delta = 0
+ else:
+ mrope_delta = model_meta['mrope_delta']
+
+ pos_start = pos_ids[0].item()
+ mrope_pos_ids = pos_ids + mrope_delta
+ mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()
+ for img in images:
+ grid_thw = img.meta['grid_thw'][0].tolist()
+ _, h, w = grid_thw
+ h //= 2
+ w //= 2
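+ # an image occupies (end - start) placeholder tokens but only advances the rope position by max(h, w); the difference is folded into mrope_delta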
+ num_pad = img.end - img.start - max(h, w)
+ mrope_delta -= num_pad
+ fill_start = img.start - pos_start
+ fill_end = img.end - pos_start
+ img_pos_ids = self._get_multimodal_pos_ids(
+ grid_thw, pos_ids.device)
+ img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
+ mrope_pos_ids[:, fill_end:] -= num_pad
+ mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids
+
+ mrope_position_ids.append(mrope_pos_ids)
+ new_model_metas.append(dict(mrope_delta=mrope_delta))
+
+ mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
+ context.mrope_position_ids = mrope_position_ids
+
+ return new_model_metas
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ if context.is_decoding:
+ return self._update_model_meta_decoding(context)
+ else:
+ return self._update_model_meta_prefilling(context)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return self.input_processor
+
+
+InputMultiModalType = List[Dict[str, Any]]
+
+
+class Qwen2VLInputProcessor(BaseModelInputProcessor):
+ """qwen2 input processor."""
+
+ def __init__(self, config: PretrainedConfig) -> None:
+ self.config = config
+
+ def preprocess_input(self,
+ input_ids: List[int],
+ input_multimodals: List[Dict[str, Any]] = None,
+ **kwargs) -> PreprocessInputResult:
+ """prepare multimodal input."""
+ if input_multimodals is None or len(input_multimodals) == 0:
+ return input_ids, input_multimodals
+
+ input_imgs = []
+ for input_mm in input_multimodals:
+ pixel_values = input_mm['pixel_values']
+ image_grid_thw = input_mm['image_grid_thw']
+ offset = input_mm['offset']
+ start = offset
+ image_token_id = input_mm.get('image_token_id', 0)
+ num_pad = input_mm['image_tokens']
+ if isinstance(num_pad, torch.Tensor):
+ num_pad = num_pad.item()
+
+ mm_data = MultiModalTensor(data=pixel_values,
+ start=start,
+ end=start + num_pad,
+ meta=dict(
+ grid_thw=image_grid_thw,
+ image_token_id=image_token_id))
+ input_imgs.append(mm_data)
+
+ result = PreprocessInputResult(
+ input_ids=input_ids,
+ input_multimodals=dict(image=input_imgs),
+ )
+ return result
diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py
new file mode 100644
index 0000000000..99bd4c4bfb
--- /dev/null
+++ b/lmdeploy/pytorch/models/utils/model.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+
+from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
+from lmdeploy.pytorch.model_inputs import StepContext
+
+
+class DeployModelMixin:
+
+ def forward(self, *args, **kwargs):
+ """forward of model."""
+ raise NotImplementedError('Not Implemented')
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None,
+ ):
+ """prepare input."""
+ raise NotImplementedError('Not Implemented')
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """load weights."""
+ raise NotImplementedError('Not Implemented')
+
+ def get_logits(self, hidden_states: torch.Tensor):
+ """compute logits of the model output."""
+ return hidden_states
+
+ def update_weights(self):
+ """update weights."""
+ pass
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """update model meta."""
+ return None
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """get input processor."""
+ return None
diff --git a/lmdeploy/pytorch/models/utils/multimodal.py b/lmdeploy/pytorch/models/utils/multimodal.py
new file mode 100644
index 0000000000..aebcaf4073
--- /dev/null
+++ b/lmdeploy/pytorch/models/utils/multimodal.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
+
+PreparedInputs = Tuple[List[int], MultiModalInputs]
+
+
+class MultiModalMixin:
+
+ def prepare_multimodal_input(self, input_ids, input_multimodals,
+ **kwargs) -> PreparedInputs:
+ """prepare multimodals inputs."""
+ raise NotImplementedError('prepare input not implemented.')
diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py
new file mode 100644
index 0000000000..c3e8c6a16f
--- /dev/null
+++ b/lmdeploy/pytorch/multimodal/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .data_type import MultiModalData, MultiModalTensor
+
+__all__ = ['MultiModalData', 'MultiModalTensor']
diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py
new file mode 100644
index 0000000000..886c7ffbd0
--- /dev/null
+++ b/lmdeploy/pytorch/multimodal/data_type.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass, fields
+from typing import Any, Dict, List, Union
+
+import torch
+from torch import Tensor
+from torch import distributed as dist
+
+
+class MultiModalData:
+ pass
+
+
+MultiModalDataList = List[MultiModalData]
+
+NestedTensor = Union[Tensor, List[Tensor]]
+
+
+def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
+ """broadcast tensor."""
+ if value.device.type == 'meta':
+ value = torch.empty_like(value, device=device)
+ dist.broadcast(value, src)
+ return value
+
+
+@dataclass
+class MultiModalTensor:
+ data: NestedTensor
+ start: int
+ end: int = None
+ encoder_len: int = None
+ meta: Dict[str, Any] = None
+
+ def __post_init__(self):
+ if self.end is None:
+ self.end = self.start
+
+ def to_device(self, device: str, non_blocking: bool = False):
+ """to device."""
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ if k in ('data', 'meta'):
+ continue
+ v = getattr(self, k)
+ out_dict[k] = v
+
+ if isinstance(self.data, Tensor):
+ data = self.data.to(device=device, non_blocking=non_blocking)
+ else:
+ data = [
+ d.to(device=device, non_blocking=non_blocking)
+ for d in self.data
+ ]
+ out_dict['data'] = data
+
+ new_meta = None
+ if self.meta is not None:
+ new_meta = dict()
+ for k, v in self.meta.items():
+ if isinstance(v, Tensor):
+ v = v.to(device=device, non_blocking=non_blocking)
+ elif hasattr(v, 'to_device'):
+ v = v.to_device(device=device, non_blocking=non_blocking)
+ new_meta[k] = v
+
+ out_dict['meta'] = new_meta
+ return MultiModalTensor(**out_dict)
+
+ def broadcast(self):
+ """broadcast inputs tensors."""
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ if k in ('data', 'meta'):
+ continue
+ v = getattr(self, k)
+ out_dict[k] = v
+
+ if isinstance(self.data, Tensor):
+ data = _broadcast_tensor(self.data)
+ else:
+ data = [_broadcast_tensor(d) for d in self.data]
+ out_dict['data'] = data
+
+ new_meta = None
+ if self.meta is not None:
+ new_meta = dict()
+ for k, v in self.meta.items():
+ if isinstance(v, Tensor):
+ v = _broadcast_tensor(v)
+ self.meta[k] = v
+ elif hasattr(v, 'to_device'):
+ assert hasattr(v, 'broadcast')
+ v = v.broadcast()
+ self.meta[k] = v
+ new_meta[k] = v
+
+ out_dict['meta'] = new_meta
+ return MultiModalTensor(**out_dict)
+
+
+MultiModalInputs = Dict[str, List[MultiModalTensor]]
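
A small usage sketch for the container above, assuming a single image whose patches were already preprocessed; all sizes are illustrative:

    import torch

    pixel_values = torch.rand(64, 1176)                      # made-up patch tensor
    mm = MultiModalTensor(data=pixel_values,
                          start=5,                           # first image pad token in input_ids
                          end=5 + 64,                        # one past the last pad token
                          meta=dict(grid_thw=torch.tensor([[1, 8, 8]])))

    mm = mm.to_device('cuda', non_blocking=True)             # stage tensors onto the GPU
    # under tensor parallelism, each rank would additionally call mm.broadcast()
    # once torch.distributed is initialized
    inputs: MultiModalInputs = dict(image=[mm])
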
diff --git a/lmdeploy/pytorch/multimodal/image_type.py b/lmdeploy/pytorch/multimodal/image_type.py
new file mode 100644
index 0000000000..19211a381f
--- /dev/null
+++ b/lmdeploy/pytorch/multimodal/image_type.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass
+from typing import Any, ClassVar, Dict
+
+from PIL import Image
+
+from .data_type import MultiModalData
+
+
+@dataclass
+class ImageData(MultiModalData):
+    data: Image.Image
+ loc: int
+ meta: Dict[str, Any] = None
+ type: ClassVar[str] = 'image'
diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py
index 63df9a5ae9..4705115bf4 100644
--- a/lmdeploy/pytorch/nn/__init__.py
+++ b/lmdeploy/pytorch/nn/__init__.py
@@ -2,7 +2,7 @@
# attention module is modified from:
# https://github.com/vllm-project/vllm/blob/main/vllm/attention/
from .activation import GeluAndMul, SiluAndMul # noqa: F401
-from .attention import Attention # noqa: F401
+from .attention import Attention, FlashAttention # noqa: F401
from .norm import LayerNorm, RMSNorm # noqa: F401
from .rotary_embedding import ApplyRotaryEmb # noqa: F401
from .rotary_embedding import RopeType # noqa: F401
diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py
index 26f1034d36..684c8122f8 100644
--- a/lmdeploy/pytorch/nn/attention.py
+++ b/lmdeploy/pytorch/nn/attention.py
@@ -9,6 +9,14 @@
from .utils import get_distribute_size
+def _update_num_heads(num_heads: int, num_kv_heads: int):
+ """update heads."""
+ world_size, rank = get_world_rank()
+ num_heads = get_distribute_size(num_heads, world_size, rank)
+ num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
+ return num_heads, num_kv_heads
+
+
class Attention(nn.Module):
"""Attention layer."""
@@ -22,15 +30,19 @@ def __init__(
alibi: bool = False,
sliding_window: int = None,
logit_softcapping: float = None,
- replicate_kv: bool = False,
+ causal: bool = True,
**kwargs,
):
super().__init__()
- num_heads, num_kv_heads = self._update_num_heads(
- num_heads, num_kv_heads, replicate_kv)
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+ if v_head_size is None:
+ v_head_size = head_size
+ num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)
layer_backend = get_backend()
- impl_builder = layer_backend.get_layer_impl_builder(OpType.Attention)
+ impl_builder = layer_backend.get_layer_impl_builder(
+ OpType.PagedAttention)
self.impl = impl_builder.build(
num_heads=num_heads,
@@ -41,18 +53,10 @@ def __init__(
alibi=alibi,
sliding_window=sliding_window,
logit_softcapping=logit_softcapping,
+ causal=causal,
**kwargs,
)
- def _update_num_heads(self, num_heads: int, num_kv_heads: int,
- replicate_kv: bool):
- """update heads."""
- world_size, rank = get_world_rank()
- num_heads = get_distribute_size(num_heads, world_size, rank)
- if not replicate_kv:
- num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
- return num_heads, num_kv_heads
-
def forward(
self,
query: torch.Tensor,
@@ -77,3 +81,75 @@ def forward(
v_scales_zeros=v_scales_zeros,
inplace=inplace,
)
+
+
+class FlashAttention(nn.Module):
+ """flash attention w/o paging."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_dim: int,
+ scale: float = None,
+ num_kv_heads: int = None,
+ v_head_dim: int = None,
+ causal: bool = True,
+ sliding_window: int = None,
+ logit_softcapping: float = None,
+ **kwargs,
+ ):
+ super().__init__()
+ if num_kv_heads is None:
+ num_kv_heads = num_heads
+ if v_head_dim is None:
+ v_head_dim = head_dim
+ num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)
+
+ layer_backend = get_backend()
+
+ impl_builder = layer_backend.get_layer_impl_builder(
+ OpType.FlashAttention)
+
+ self.impl = impl_builder.build(
+ num_heads=num_heads,
+ head_dim=head_dim,
+ scale=scale,
+ num_kv_heads=num_kv_heads,
+ v_head_dim=v_head_dim,
+ causal=causal,
+ sliding_window=sliding_window,
+ logit_softcapping=logit_softcapping,
+ **kwargs,
+ )
+
+ def forward(self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ q_start_loc: torch.Tensor,
+ q_seqlens: torch.Tensor,
+ kv_start_loc: torch.Tensor = None,
+ kv_seqlens: torch.Tensor = None,
+ max_q_seqlen: int = None) -> torch.Tensor:
+ """forward."""
+
+ if max_q_seqlen is None:
+ max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+
+ if kv_start_loc is None and kv_seqlens is None:
+ kv_start_loc = q_start_loc
+ kv_seqlens = q_seqlens
+
+ assert kv_start_loc is not None
+ assert kv_seqlens is not None
+
+ return self.impl.forward(
+ query,
+ key,
+ value,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_q_seqlen=max_q_seqlen,
+ )
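
A usage sketch for the new unpaged `FlashAttention` module, assuming a single-GPU run where the CUDA backend has been selected; two variable-length sequences are packed along the token dimension and the kv lengths default to the query lengths:

    import torch

    attn = FlashAttention(num_heads=8, head_dim=64)   # kv heads / v_head_dim default to q's

    # sequences of length 3 and 5 packed into 8 tokens
    q = torch.rand(8, 8, 64, dtype=torch.float16, device='cuda')
    k = torch.rand(8, 8, 64, dtype=torch.float16, device='cuda')
    v = torch.rand(8, 8, 64, dtype=torch.float16, device='cuda')
    q_start_loc = torch.tensor([0, 3], device='cuda')
    q_seqlens = torch.tensor([3, 5], device='cuda')

    out = attn(q, k, v, q_start_loc=q_start_loc, q_seqlens=q_seqlens)
    # `out` keeps the packed layout (num_tokens, num_heads, v_head_dim)
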
diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py
index 08040ee00c..486c684a3c 100644
--- a/lmdeploy/pytorch/nn/linear.py
+++ b/lmdeploy/pytorch/nn/linear.py
@@ -12,7 +12,7 @@
from ..backends import OpType, get_backend
from ..backends.lora import AdapterInfo
-from .utils import div_up, get_distribute_size
+from .utils import get_distribute_size
logger = get_logger('lmdeploy')
@@ -32,9 +32,11 @@ def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int):
size = weight.size(dim)
assert size % align == 0
aligned_size = size // align
- align_per_chunk = div_up(aligned_size, chunks)
- sections = [align_per_chunk] * (chunks - 1)
- sections += [aligned_size - align_per_chunk * (chunks - 1)]
+
+    # split into chunks as evenly as possible
+ align_per_chunk = aligned_size // chunks
+ remain = aligned_size % chunks
+ sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
sections = [sec * align for sec in sections]
return weight.split(sections, dim=dim)
@@ -42,20 +44,24 @@ def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int):
class QKVMixin:
"""qkv mixin."""
- def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int,
- head_size: int, head_size_v: int):
+ def _get_qkv_out_features(self,
+ num_q_heads: int,
+ num_kv_heads: int,
+ head_size: int,
+ head_size_v: int,
+ num_replicate_kv_heads: int = 1):
"""get io features."""
- all_out_features = (num_q_heads * head_size, num_kv_heads * head_size,
- num_kv_heads * head_size_v)
+ num_kv_heads_real = num_kv_heads // num_replicate_kv_heads
+ all_out_features = (num_q_heads * head_size,
+ num_kv_heads_real * head_size,
+ num_kv_heads_real * head_size_v)
return all_out_features
- def _update_num_heads(self, num_q_heads: int, num_kv_heads: int,
- replicate_kv: bool):
+ def _update_num_heads(self, num_q_heads: int, num_kv_heads: int):
"""update num heads."""
world_size, rank = get_world_rank()
num_q_heads = get_distribute_size(num_q_heads, world_size, rank)
- if not replicate_kv:
- num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
+ num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)
return num_q_heads, num_kv_heads
@@ -212,7 +218,7 @@ def __init__(
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size
- self.elem_per_int = 32 // self.w_bit
+ self.elem_per_int = 32 // w_bit
self.lora_adapters = nn.ModuleDict()
self.is_tp = is_tp
self.colwise = colwise
@@ -363,12 +369,9 @@ def __init__(self,
w_bit: int,
group_size: int,
bias: bool,
- replicate: Optional[List[bool]] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
out_names: Optional[List[int]] = None):
- if replicate is None:
- replicate = tuple(False for _ in all_out_features)
self.split_section_s = all_out_features
elem_per_int = 32 // w_bit
@@ -377,9 +380,8 @@ def __init__(self,
]
all_out_features = self._update_all_out_features(
- all_out_features, w_bit, group_size, replicate)
+ all_out_features, w_bit, group_size)
self.all_out_features = all_out_features
- self.replicate = replicate
if out_names is None:
out_names = torch.arange(len(self.all_out_features)).tolist()
assert len(out_names) == len(self.all_out_features)
@@ -414,15 +416,12 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int,
return in_features, out_features
def _update_all_out_features(self, all_out_features: List[int], w_bit: int,
- group_size: int,
- replicate: Optional[List[bool]]):
+ group_size: int):
"""update all out features."""
world_size, rank = get_world_rank()
new_all_out_features = []
align = max(32 // w_bit, group_size)
- for out_feat, rep in zip(all_out_features, replicate):
- if rep:
- new_all_out_features.append(out_feat)
+ for out_feat in all_out_features:
new_out_feat = get_distribute_size(out_feat, world_size, rank,
align)
new_all_out_features.append(new_out_feat)
@@ -433,14 +432,11 @@ def weight_loader(self, param: torch.nn.Parameter,
"""weight loader."""
world_size, rank = get_world_rank()
shard_idx = self.out_names_map[shard_id]
-
if loaded_weight.dim() == 1:
# bias
align = max(self.elem_per_int, self.group_size)
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- weight = _chunk_align(loaded_weight, world_size, 0,
- align)[rank]
+ weight = _chunk_align(loaded_weight, world_size, 0, align)[rank]
param_w.copy_(weight)
if param._weight_type in ['scales', 'bias']:
@@ -456,8 +452,7 @@ def weight_loader(self, param: torch.nn.Parameter,
]
param_w = param.data.split(quanted_out_feats, 1)[shard_idx]
- if not self.replicate[shard_idx]:
- weight = _chunk_align(loaded_weight, world_size, -1, align)[rank]
+ weight = _chunk_align(loaded_weight, world_size, -1, align)[rank]
param_w.copy_(weight)
def weight_spliter_wz(self, loaded_weight: torch.Tensor):
@@ -480,45 +475,82 @@ def __init__(self,
head_size_v: int,
w_bit: int,
group_size: int,
- replicate_kv: bool = False,
bias: bool = False,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1):
self.qkv_split_section_s = self._get_qkv_out_features(
- num_q_heads, num_kv_heads, head_size, head_size_v)
+ num_q_heads, num_kv_heads, head_size, head_size_v,
+ num_replicate_kv_heads)
elem_per_int = 32 // w_bit
self.qkv_split_section_wz = [
size // elem_per_int for size in self.qkv_split_section_s
]
num_q_heads, num_kv_heads = self._update_num_heads(
- num_q_heads, num_kv_heads, replicate_kv)
+ num_q_heads, num_kv_heads)
all_out_features = self._get_qkv_out_features(num_q_heads,
num_kv_heads, head_size,
head_size_v)
- replicate = (False, replicate_kv, replicate_kv)
out_names = ('q', 'k', 'v')
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.head_size_v = head_size_v
+ self.num_replicate_kv_heads = num_replicate_kv_heads
+
super().__init__(in_features,
all_out_features,
w_bit=w_bit,
group_size=group_size,
bias=bias,
- replicate=replicate,
device=device,
is_tp=is_tp,
out_names=out_names)
def _update_all_out_features(self, all_out_features: List[int], w_bit: int,
- group_size: int,
- replicate: Optional[List[bool]]):
+ group_size: int):
"""update all out features."""
return all_out_features
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, shard_id: Any):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ chunk_size, chunk_idx = world_size, rank
+ shard_idx = self.out_names_map[shard_id]
+
+ if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']:
+ # update to duplicate k/v for tp_size > num_kv_heads
+ chunk_size = world_size // self.num_replicate_kv_heads
+ chunk_idx = rank // self.num_replicate_kv_heads
+
+ if loaded_weight.dim() == 1:
+ # bias
+ align = max(self.elem_per_int, self.group_size)
+ param_w = param.data.split(self.all_out_features, 0)[shard_idx]
+ weight = _chunk_align(loaded_weight, chunk_size, 0,
+ align)[chunk_idx]
+ param_w.copy_(weight)
+ return
+
+ if param._weight_type in ['scales', 'bias']:
+ # scales
+ align = max(self.elem_per_int, self.group_size)
+ param_w = param.data.split(self.all_out_features, -1)[shard_idx]
+ else:
+ # qweight or qzeros
+ align = max(self.elem_per_int,
+ self.group_size) // self.elem_per_int
+ quanted_out_feats = [
+ feat // self.elem_per_int for feat in self.all_out_features
+ ]
+ param_w = param.data.split(quanted_out_feats, 1)[shard_idx]
+
+ weight = _chunk_align(loaded_weight, chunk_size, -1, align)[chunk_idx]
+ param_w.copy_(weight)
+
def weight_spliter_wz(self,
loaded_weight: torch.Tensor,
layout: str = 'default'):
@@ -710,18 +742,13 @@ def __init__(self,
in_features: int,
all_out_features: List[int],
bias: bool,
- replicate: Optional[List[bool]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
out_names: Optional[List[int]] = None):
- if replicate is None:
- replicate = tuple(False for _ in all_out_features)
self.split_section = all_out_features
- all_out_features = self._update_all_out_features(
- all_out_features, replicate)
+ all_out_features = self._update_all_out_features(all_out_features)
self.all_out_features = all_out_features
- self.replicate = replicate
if out_names is None:
out_names = torch.arange(len(self.all_out_features)).tolist()
assert len(out_names) == len(self.all_out_features)
@@ -748,14 +775,11 @@ def _get_io_features(self, in_features: int, out_features: int,
"""get io features."""
return in_features, out_features
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
world_size, rank = get_world_rank()
new_all_out_features = []
- for out_feat, rep in zip(all_out_features, replicate):
- if rep:
- new_all_out_features.append(out_feat)
+ for out_feat in all_out_features:
new_out_feat = get_distribute_size(out_feat, world_size, rank)
new_all_out_features.append(new_out_feat)
return new_all_out_features
@@ -766,8 +790,7 @@ def weight_loader(self, param: torch.nn.Parameter,
world_size, rank = get_world_rank()
shard_idx = self.out_names_map[shard_id]
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
+ loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
param_w.copy_(loaded_weight)
def weight_spliter(self, loaded_weight: torch.Tensor):
@@ -787,38 +810,57 @@ def __init__(self,
num_kv_heads: int,
head_size: int,
head_size_v: int,
- replicate_kv: bool = False,
bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1):
self.qkv_split_section = self._get_qkv_out_features(
- num_q_heads, num_kv_heads, head_size, head_size_v)
+ num_q_heads, num_kv_heads, head_size, head_size_v,
+ num_replicate_kv_heads)
num_q_heads, num_kv_heads = self._update_num_heads(
- num_q_heads, num_kv_heads, replicate_kv)
+ num_q_heads, num_kv_heads)
all_out_features = self._get_qkv_out_features(num_q_heads,
num_kv_heads, head_size,
head_size_v)
- replicate = (False, replicate_kv, replicate_kv)
out_names = ('q', 'k', 'v')
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.head_size_v = head_size_v
+ self.num_replicate_kv_heads = num_replicate_kv_heads
super().__init__(in_features,
all_out_features,
bias=bias,
- replicate=replicate,
dtype=dtype,
device=device,
is_tp=is_tp,
out_names=out_names)
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
return all_out_features
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, shard_id: Any):
+ """weight loader."""
+ _, rank = get_world_rank()
+ shard_idx = self.out_names_map[shard_id]
+ param_w = param.data.split(self.all_out_features, 0)[shard_idx]
+ num_head = self.num_q_heads if shard_id == 'q' \
+ else self.num_kv_heads
+ head_dim = self.head_size if shard_id in ['q', 'k'] \
+ else self.head_size_v
+ # update to duplicate k/v for tp_size > num_kv_heads
+ rank_idx = rank if shard_id == 'q' \
+ else rank // self.num_replicate_kv_heads
+ sec_start = rank_idx * num_head * head_dim
+ sec_len = num_head * head_dim
+ loaded_weight = loaded_weight.narrow(dim=0,
+ start=sec_start,
+ length=sec_len)
+ param_w.copy_(loaded_weight)
+
def weight_spliter(self,
loaded_weight: torch.Tensor,
layout: str = 'default'):
@@ -986,18 +1028,13 @@ def __init__(self,
in_features: int,
all_out_features: List[int],
bias: bool,
- replicate: Optional[List[bool]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
out_names: Optional[List[int]] = None):
- if replicate is None:
- replicate = tuple(False for _ in all_out_features)
self.split_section = all_out_features
- all_out_features = self._update_all_out_features(
- all_out_features, replicate)
+ all_out_features = self._update_all_out_features(all_out_features)
self.all_out_features = all_out_features
- self.replicate = replicate
if out_names is None:
out_names = torch.arange(len(self.all_out_features)).tolist()
assert len(out_names) == len(self.all_out_features)
@@ -1022,14 +1059,11 @@ def _get_io_features(self, in_features: int, out_features: int,
"""get io features."""
return in_features, out_features
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
world_size, rank = get_world_rank()
new_all_out_features = []
- for out_feat, rep in zip(all_out_features, replicate):
- if rep:
- new_all_out_features.append(out_feat)
+ for out_feat in all_out_features:
new_out_feat = get_distribute_size(out_feat, world_size, rank)
new_all_out_features.append(new_out_feat)
return new_all_out_features
@@ -1040,8 +1074,7 @@ def weight_loader(self, param: torch.nn.Parameter,
world_size, rank = get_world_rank()
shard_idx = self.out_names_map[shard_id]
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
+ loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
param_w.copy_(loaded_weight)
def weight_spliter(self, loaded_weight: torch.Tensor):
@@ -1061,35 +1094,36 @@ def __init__(self,
num_kv_heads: int,
head_size: int,
head_size_v: int,
- replicate_kv: bool = False,
bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1):
+
self.qkv_split_section = self._get_qkv_out_features(
- num_q_heads, num_kv_heads, head_size, head_size_v)
+ num_q_heads, num_kv_heads, head_size, head_size_v,
+ num_replicate_kv_heads)
num_q_heads, num_kv_heads = self._update_num_heads(
- num_q_heads, num_kv_heads, replicate_kv)
+ num_q_heads, num_kv_heads)
all_out_features = self._get_qkv_out_features(num_q_heads,
num_kv_heads, head_size,
head_size_v)
- replicate = (False, replicate_kv, replicate_kv)
out_names = ('q', 'k', 'v')
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.head_size_v = head_size_v
+ self.num_replicate_kv_heads = num_replicate_kv_heads
+
super().__init__(in_features,
all_out_features,
bias=bias,
- replicate=replicate,
dtype=dtype,
device=device,
is_tp=is_tp,
out_names=out_names)
- def _update_all_out_features(self, all_out_features: List[int],
- replicate: Optional[List[bool]]):
+ def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
return all_out_features
@@ -1097,15 +1131,20 @@ def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, shard_id: Any):
"""weight loader."""
world_size, rank = get_world_rank()
+ chunk_size, chunk_idx = world_size, rank
shard_idx = self.out_names_map[shard_id]
param_w = param.data.split(self.all_out_features, 0)[shard_idx]
- if not self.replicate[shard_idx]:
- if shard_idx in [0, 1]:
- loaded_weight = _chunk_align(loaded_weight, world_size, 0,
- self.head_size)[rank]
- if shard_idx == 2:
- loaded_weight = _chunk_align(loaded_weight, world_size, 0,
- self.head_size_v)[rank]
+
+ if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']:
+ # update to duplicate k/v for tp_size > num_kv_heads
+ chunk_size = world_size // self.num_replicate_kv_heads
+ chunk_idx = rank // self.num_replicate_kv_heads
+ if shard_idx in [0, 1]:
+ loaded_weight = _chunk_align(loaded_weight, chunk_size, 0,
+ self.head_size)[chunk_idx]
+ elif shard_idx == 2:
+ loaded_weight = _chunk_align(loaded_weight, chunk_size, 0,
+ self.head_size_v)[chunk_idx]
param_w.copy_(loaded_weight)
def weight_spliter(self,
@@ -1291,12 +1330,12 @@ def build_qkv_proj(in_features: int,
num_kv_heads: int,
head_size: int,
head_size_v: int = None,
- replicate_kv: bool = False,
bias: bool = False,
quant_config: Any = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
- is_tp: bool = True):
+ is_tp: bool = True,
+ num_replicate_kv_heads: int = 1):
"""build qkv proj."""
if is_tp:
world_size, _ = get_world_rank()
@@ -1306,48 +1345,42 @@ def build_qkv_proj(in_features: int,
head_size_v = head_size
if quant_config is None:
- return QKVBaseLinear(
- in_features=in_features,
- num_q_heads=num_q_heads,
- num_kv_heads=num_kv_heads,
- head_size=head_size,
- head_size_v=head_size_v,
- replicate_kv=replicate_kv,
- bias=bias,
- dtype=dtype,
- device=device,
- is_tp=is_tp,
- )
+ return QKVBaseLinear(in_features=in_features,
+ num_q_heads=num_q_heads,
+ num_kv_heads=num_kv_heads,
+ head_size=head_size,
+ head_size_v=head_size_v,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ is_tp=is_tp,
+ num_replicate_kv_heads=num_replicate_kv_heads)
quant_method = quant_config['quant_method']
if quant_method == 'awq':
w_bit = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
- return QKVAwqLinear(
- in_features=in_features,
- num_q_heads=num_q_heads,
- num_kv_heads=num_kv_heads,
- head_size=head_size,
- head_size_v=head_size_v,
- replicate_kv=replicate_kv,
- w_bit=w_bit,
- group_size=group_size,
- bias=bias,
- device=device,
- is_tp=is_tp,
- )
+ return QKVAwqLinear(in_features=in_features,
+ num_q_heads=num_q_heads,
+ num_kv_heads=num_kv_heads,
+ head_size=head_size,
+ head_size_v=head_size_v,
+ w_bit=w_bit,
+ group_size=group_size,
+ bias=bias,
+ device=device,
+ is_tp=is_tp,
+ num_replicate_kv_heads=num_replicate_kv_heads)
if quant_method == 'smooth_quant':
- return QKVW8A8Linear(
- in_features=in_features,
- num_q_heads=num_q_heads,
- num_kv_heads=num_kv_heads,
- head_size=head_size,
- head_size_v=head_size_v,
- replicate_kv=replicate_kv,
- bias=bias,
- dtype=dtype,
- device=device,
- is_tp=is_tp,
- )
+ return QKVW8A8Linear(in_features=in_features,
+ num_q_heads=num_q_heads,
+ num_kv_heads=num_kv_heads,
+ head_size=head_size,
+ head_size_v=head_size_v,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ is_tp=is_tp,
+ num_replicate_kv_heads=num_replicate_kv_heads)
else:
raise RuntimeError(f'Unsupported quant method: {quant_method}')
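
With `replicate_kv` gone, callers express replication through `num_replicate_kv_heads`. A sketch of how a model definition might compute it when tp exceeds the checkpoint's kv head count; the head counts and hidden size below are assumptions:

    tp = 8
    num_q_heads, num_kv_heads, head_dim = 32, 2, 128

    # each of the 2 kv heads is copied onto 4 ranks
    num_replicate_kv_heads = max(1, tp // num_kv_heads)

    qkv_proj = build_qkv_proj(
        4096,                                                 # in_features (hidden size, made up)
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads * num_replicate_kv_heads,   # head count after replication
        head_size=head_dim,
        bias=False,
        num_replicate_kv_heads=num_replicate_kv_heads,
    )
    # the weight loaders above then slice the original 2-head k/v weights and
    # copy each slice to the ranks that share it
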
diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py
index 3289f858a7..3b60ca21de 100644
--- a/lmdeploy/pytorch/nn/utils.py
+++ b/lmdeploy/pytorch/nn/utils.py
@@ -11,7 +11,10 @@ def get_distribute_size(feature_size: int,
"""update feature size."""
assert feature_size % align == 0
aligned_size = feature_size // align
- align_per_rank = div_up(aligned_size, world_size)
- prev_feats = align_per_rank * rank
- updated_aligned_size = min(align_per_rank, aligned_size - prev_feats)
+    # give every rank the same base number of features
+    updated_aligned_size = aligned_size // world_size
+    # hand the remainder out one extra feature at a time,
+    # starting from the lowest ranks
+ if rank < aligned_size % world_size:
+ updated_aligned_size += 1
return updated_aligned_size * align
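
A quick numeric check of the new split (a sketch; the old `div_up` behaviour is shown for comparison): 10 aligned groups over 4 ranks now split 3/3/2/2 instead of 3/3/3/1.

    # feature_size=320, align=32 -> 10 aligned groups of 32
    sizes = [get_distribute_size(320, 4, rank, 32) for rank in range(4)]
    assert sizes == [96, 96, 64, 64]   # new: remainder goes to the lowest ranks
    assert sum(sizes) == 320           # old div_up split was [96, 96, 96, 32]
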
diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py
index 7fa568651b..67452f78e3 100644
--- a/lmdeploy/pytorch/supported_models.py
+++ b/lmdeploy/pytorch/supported_models.py
@@ -47,9 +47,9 @@
# cogvlm-chat
CogVLMForCausalLM=True,
# llava
- LlavaLlamaForCausalLM=True,
+ LlavaLlamaForCausalLM=False,
# llava mistral
- LlavaMistralForCausalLM=True,
+ LlavaMistralForCausalLM=False,
# deepseekvl
MultiModalityCausalLM=False,
# StarCoder2
diff --git a/lmdeploy/pytorch/tools/make_inputs.py b/lmdeploy/pytorch/tools/make_inputs.py
index f2d23830b7..053e7d0918 100644
--- a/lmdeploy/pytorch/tools/make_inputs.py
+++ b/lmdeploy/pytorch/tools/make_inputs.py
@@ -135,6 +135,7 @@ def __fill_kv_caches(kv_caches, past_key_values, block_offsets):
return StepContext.new(
inputs=model_inputs,
+ model_config=model_config,
world_size=world_size,
kv_caches=kv_caches,
)
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index e9083fc110..0e1a01f4f4 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -325,9 +325,10 @@ def __call__(self,
"""Inference a batch of prompts.
Args:
- prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
- prompts. It accepts: string prompt, a list of string prompts,
- a chat history in OpenAI format or a list of chat history.
+            prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
+ batch of prompts. It accepts: string prompt, a list of string
+ prompts, a chat history in OpenAI format or a list of chat
+ history.
gen_config (GenerationConfig | None): a instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
@@ -471,9 +472,10 @@ def batch_infer(self,
"""Inference a batch of prompts.
Args:
- prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
- prompts. It accepts: string prompt, a list of string prompts,
- a chat history in OpenAI format or a list of chat history.
+            prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
+ batch of prompts. It accepts: string prompt, a list of string
+ prompts, a chat history in OpenAI format or a list of chat
+ history.
gen_config (GenerationConfig | None): a instance of or a list of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
@@ -516,9 +518,10 @@ def stream_infer(
"""Inference a batch of prompts with stream mode.
Args:
- prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
- prompts. It accepts: string prompt, a list of string prompts,
- a chat history in OpenAI format or a list of chat history.
+            prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
+ batch of prompts. It accepts: string prompt, a list of string
+ prompts, a chat history in OpenAI format or a list of chat
+ history.
gen_config (GenerationConfig | None): a instance of or a list of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
@@ -622,8 +625,8 @@ async def generate(
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
if not gen_config.do_sample:
- logger.warn(f'GenerationConfig: {gen_config}')
- logger.warn(
+ logger.warning(f'GenerationConfig: {gen_config}')
+ logger.warning(
'Since v0.6.0, lmdeploy add `do_sample` in '
'GenerationConfig. It defaults to False, meaning greedy '
'decoding. Please set `do_sample=True` if sampling '
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 2d0560720d..cce9567896 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -946,6 +946,20 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return JSONResponse(ret)
+def handle_torchrun():
+ """To disable mmengine logging logic when using torchrun."""
+
+ def dummy_get_device_id():
+ return 0
+
+ if int(os.environ.get('LOCAL_RANK', -1)) > 0:
+ from lmdeploy.vl.model.utils import _set_func
+
+ # the replacement can't be recovered
+ _set_func('mmengine.logging.logger._get_device_id',
+ dummy_get_device_id)
+
+
@router.on_event('startup')
async def startup_event():
if VariableInterface.proxy_url is None:
@@ -1069,8 +1083,8 @@ def serve(model_path: str,
ssl_certfile = os.environ['SSL_CERTFILE']
http_or_https = 'https'
+ handle_torchrun()
_, pipeline_class = get_task(model_path)
-
VariableInterface.async_engine = pipeline_class(
model_path=model_path,
model_name=model_name,
diff --git a/lmdeploy/serve/proxy/constants.py b/lmdeploy/serve/proxy/constants.py
index 88d86a3e33..5bf6e67659 100644
--- a/lmdeploy/serve/proxy/constants.py
+++ b/lmdeploy/serve/proxy/constants.py
@@ -2,8 +2,8 @@
import enum
-LATENCY_DEEQUE_LEN = 15
-API_TIMEOUT_LEN = 100
+LATENCY_DEQUE_LEN = 15
+API_READ_TIMEOUT = 100
class Strategy(enum.Enum):
diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py
index 5f05930bd0..392ede3267 100644
--- a/lmdeploy/serve/proxy/proxy.py
+++ b/lmdeploy/serve/proxy/proxy.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
import copy
import json
import os
@@ -18,14 +19,15 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
+from requests.exceptions import RequestException
from lmdeploy.serve.openai.api_server import (check_api_key,
create_error_response)
from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, CompletionRequest, ModelCard, ModelList,
ModelPermission)
-from lmdeploy.serve.proxy.constants import (API_TIMEOUT_LEN,
- LATENCY_DEEQUE_LEN, ErrorCodes,
+from lmdeploy.serve.proxy.constants import (API_READ_TIMEOUT,
+ LATENCY_DEQUE_LEN, ErrorCodes,
Strategy, err_msg)
from lmdeploy.utils import get_logger
@@ -36,7 +38,7 @@ class Status(BaseModel):
"""Status protocol consists of models' information."""
models: Optional[List[str]] = Field(default=[], examples=[[]])
unfinished: int = 0
- latency: Deque = Field(default=deque(maxlen=LATENCY_DEEQUE_LEN),
+ latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN),
examples=[[]])
speed: Optional[int] = Field(default=None, examples=[None])
@@ -87,6 +89,9 @@ def __init__(self,
with open(self.config_path, 'r') as config_file:
self.nodes = yaml.safe_load(config_file)['nodes']
for url, status in self.nodes.items():
+ latency = deque(status.get('latency', []),
+ maxlen=LATENCY_DEQUE_LEN)
+ status['latency'] = latency
status = Status(**status)
self.nodes[url] = status
self.heart_beat_thread = threading.Thread(target=heart_beat_controller,
@@ -99,7 +104,7 @@ def update_config_file(self):
nodes = copy.deepcopy(self.nodes)
for url, status in nodes.items():
nodes[url] = status.model_dump()
- nodes[url]['latency'] = list(status.latency)
+ nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:]
with open(self.config_path, 'w') as config_file: # update cfg yml
yaml.dump(dict(nodes=nodes), config_file)
@@ -149,7 +154,8 @@ def remove_stale_nodes_by_expiration(self):
to_be_deleted.append(node_url)
for node_url in to_be_deleted:
self.remove(node_url)
- logger.info(f'Removed node_url: {node_url}')
+ logger.info(f'Removed node_url: {node_url} '
+ 'due to heart beat expiration')
@property
def model_list(self):
@@ -251,7 +257,7 @@ def handle_unavailable_model(self, model_name):
Args:
model_name (str): the model in the request.
"""
- logger.info(f'no model name: {model_name}')
+ logger.warning(f'no model name: {model_name}')
ret = {
'error_code': ErrorCodes.MODEL_NOT_FOUND,
'text': err_msg[ErrorCodes.MODEL_NOT_FOUND],
@@ -260,51 +266,54 @@ def handle_unavailable_model(self, model_name):
def handle_api_timeout(self, node_url):
"""Handle the api time out."""
- logger.info(f'api timeout: {node_url}')
+ logger.warning(f'api timeout: {node_url}')
ret = {
- 'error_code': ErrorCodes.API_TIMEOUT,
+ 'error_code': ErrorCodes.API_TIMEOUT.value,
'text': err_msg[ErrorCodes.API_TIMEOUT],
}
return json.dumps(ret).encode() + b'\n'
- def stream_generate(self, request: Dict, node_url: str, node_path: str):
+ def stream_generate(self, request: Dict, node_url: str, endpoint: str):
"""Return a generator to handle the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
- node_path (str): the node path. Such as `/v1/chat/completions`.
+ endpoint (str): the endpoint. Such as `/v1/chat/completions`.
"""
try:
response = requests.post(
- node_url + node_path,
+ node_url + endpoint,
json=request,
- stream=request['stream'],
- timeout=API_TIMEOUT_LEN,
+ stream=True,
+ timeout=(5, API_READ_TIMEOUT),
)
for chunk in response.iter_lines(decode_unicode=False,
delimiter=b'\n'):
if chunk:
yield chunk + b'\n\n'
- except requests.exceptions.RequestException as e: # noqa
+ except (Exception, GeneratorExit, RequestException) as e: # noqa
+            logger.error(f'caught an exception: {e}')
+ # exception happened, reduce unfinished num
yield self.handle_api_timeout(node_url)
- async def generate(self, request: Dict, node_url: str, node_path: str):
+ async def generate(self, request: Dict, node_url: str, endpoint: str):
"""Return a the response of the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
- node_path (str): the node path. Such as `/v1/chat/completions`.
+ endpoint (str): the endpoint. Such as `/v1/chat/completions`.
"""
try:
import httpx
async with httpx.AsyncClient() as client:
- response = await client.post(node_url + node_path,
+ response = await client.post(node_url + endpoint,
json=request,
- timeout=API_TIMEOUT_LEN)
+ timeout=API_READ_TIMEOUT)
return response.text
- except requests.exceptions.RequestException as e: # noqa
+ except (Exception, GeneratorExit, RequestException, asyncio.CancelledError) as e: # noqa # yapf: disable
+            logger.error(f'caught an exception: {e}')
return self.handle_api_timeout(node_url)
def pre_call(self, node_url):
@@ -381,7 +390,11 @@ def add_node(node: Node, raw_request: Request = None):
RPM or other metric. All the values of nodes should be the same metric.
"""
try:
- node_manager.add(node.url, node.status)
+ res = node_manager.add(node.url, node.status)
+ if res is not None:
+ logger.error(f'add node {node.url} failed, {res}')
+ return res
+ logger.info(f'add node {node.url} successfully')
return 'Added successfully'
except: # noqa
return 'Failed to add, please check the input url.'
@@ -392,8 +405,10 @@ def remove_node(node_url: str):
"""Show available models."""
try:
node_manager.remove(node_url)
+ logger.info(f'delete node {node_url} successfully')
return 'Deleted successfully'
except: # noqa
+ logger.error(f'delete node {node_url} failed.')
return 'Failed to delete, please check the input url.'
@@ -407,28 +422,50 @@ async def chat_completions_v1(request: ChatCompletionRequest,
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- - messages: string prompt or chat history in OpenAI format. A example
- for chat history is `[{"role": "user", "content":"knock knock"}]`.
+ - messages: string prompt or chat history in OpenAI format. Chat history
+ example: `[{"role": "user", "content": "hi"}]`.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
- message. Only support one here.
+ message. **Only support one here**.
- stream: whether to stream the results or not. Default to false.
- - max_tokens (int): output token nums
+ - max_tokens (int | None): output token nums. Default to None.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that's encoded to one token idex.
+    - response_format (Dict | None): Only the pytorch backend supports
+      formatting the response. Examples: `{"type": "json_schema",
+      "json_schema": {"name": "test", "schema": {"properties": {"name":
+      {"type": "string"}}, "required": ["name"], "type": "object"}}}`
+      or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}`
+ - logit_bias (Dict): Bias to logits. Only supported in pytorch engine.
+ - tools (List): A list of tools the model may call. Currently, only
+ internlm2 functions are supported as a tool. Use this to specify a
+ list of functions for which the model can generate JSON inputs.
+ - tool_choice (str | object): Controls which (if any) tool is called by
+ the model. `none` means the model will not call any tool and instead
+ generates a message. Specifying a particular tool via {"type":
+ "function", "function": {"name": "my_function"}} forces the model to
+      call that tool. `auto` or `required` will pass all the tool
+      information to the model.
Additional arguments supported by LMDeploy:
+ - top_k (int): The number of the highest probability vocabulary
+ tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- - session_id (int): if not specified, will set random value
+ - skip_special_tokens (bool): Whether or not to remove special tokens
+      in the decoding. Defaults to True.
+    - min_new_tokens (int): To generate at least this number of tokens.
+ - min_p (float): Minimum token probability, which will be scaled by the
+ probability of the most likely token. It must be a value between
+ 0 and 1. Typical values are in the 0.01-0.2 range, comparably
+ selective as setting `top_p` in the 0.99-0.8 range (use the
+ opposite of normal `top_p` values)
Currently we do not support the following features:
- - function_call (Users should implement this by themselves)
- - logit_bias (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
@@ -439,6 +476,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
if not node_url:
return node_manager.handle_unavailable_model(request.model)
+ logger.info(f'A request is dispatched to {node_url}')
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
@@ -465,13 +503,13 @@ async def completions_v1(request: CompletionRequest,
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- - max_tokens (int): output token nums
+ - max_tokens (int): output token nums. Default to 16.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
- message. Only support one here.
+ message. **Only support one here**.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
@@ -481,7 +519,8 @@ async def completions_v1(request: CompletionRequest,
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- - session_id (int): if not specified, will set random value
+ - skip_special_tokens (bool): Whether or not to remove special tokens
+      in the decoding. Defaults to True.
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
@@ -497,6 +536,7 @@ async def completions_v1(request: CompletionRequest,
if not node_url:
return node_manager.handle_unavailable_model(request.model)
+ logger.info(f'A request is dispatched to {node_url}')
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
@@ -517,6 +557,7 @@ def proxy(server_name: str = '0.0.0.0',
'min_observed_latency'] = 'min_expected_latency',
api_keys: Optional[Union[List[str], str]] = None,
ssl: bool = False,
+ log_level: str = 'INFO',
**kwargs):
"""To launch the proxy server.
@@ -540,6 +581,7 @@ def proxy(server_name: str = '0.0.0.0',
if ssl:
ssl_keyfile = os.environ['SSL_KEYFILE']
ssl_certfile = os.environ['SSL_CERTFILE']
+ logger.setLevel(log_level)
uvicorn.run(app=app,
host=server_name,
port=server_port,
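
On the `timeout=(5, API_READ_TIMEOUT)` change above: `requests` interprets a tuple as separate connect and read timeouts, so a node only has 5 s to accept the connection while a healthy stream may idle up to 100 s between chunks. A hedged illustration with a made-up node address:

    import requests

    try:
        resp = requests.post('http://10.0.0.2:23333/v1/chat/completions',   # hypothetical node
                             json={'model': 'internlm2', 'messages': 'hi', 'stream': True},
                             stream=True,
                             timeout=(5, 100))   # (connect timeout, read timeout) in seconds
        for chunk in resp.iter_lines(decode_unicode=False, delimiter=b'\n'):
            if chunk:
                print(chunk)
    except requests.exceptions.RequestException as err:
        print(f'node unreachable or read timed out: {err}')
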
diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py
index c293cd71c8..becf1b76fb 100644
--- a/lmdeploy/serve/vl_async_engine.py
+++ b/lmdeploy/serve/vl_async_engine.py
@@ -1,148 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Union
+import asyncio
+from typing import Dict, List, Literal, Optional, Tuple, Union
-import numpy as np
+import PIL
+from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig,
+ VisionConfig)
from lmdeploy.pytorch.check_env import try_import_deeplink
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.utils import get_logger
-from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX, IMAGE_TOKEN
from lmdeploy.vl.engine import ImageEncoder
-from lmdeploy.vl.templates import VLPromptType, get_vl_prompt_template
+from lmdeploy.vl.utils import load_image
logger = get_logger('lmdeploy')
+VLPromptType = Union[str, Tuple[str, PIL.Image.Image],
+ Tuple[str, List[PIL.Image.Image]]]
+
class VLAsyncEngine(AsyncEngine):
"""Visual Language Async inference engine."""
- def __init__(self, model_path: str, **kwargs) -> None:
- vision_config = kwargs.pop('vision_config', None)
- backend_config = kwargs.get('backend_config', None)
- if kwargs.get('backend', '') == 'pytorch':
+ def __init__(self,
+ model_path: str,
+ backend: Literal['turbomind', 'pytorch'] = 'turbomind',
+ backend_config: Optional[Union[TurbomindEngineConfig,
+ PytorchEngineConfig]] = None,
+ vision_config: Optional[VisionConfig] = None,
+ **kwargs) -> None:
+ if backend == 'pytorch':
try_import_deeplink(backend_config.device_type)
self.vl_encoder = ImageEncoder(model_path,
+ backend,
vision_config,
backend_config=backend_config)
- super().__init__(model_path, **kwargs)
+ super().__init__(model_path,
+ backend=backend,
+ backend_config=backend_config,
+ **kwargs)
if self.model_name == 'base':
raise RuntimeError(
'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501
)
- self.vl_prompt_template = get_vl_prompt_template(
- model_path, self.chat_template, self.model_name)
- def _convert_prompts(self,
+ @classmethod
+ def _convert_prompts(cls,
prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]]):
- """convert prompts to openai format."""
+ """convert prompts to openai GPT4V format."""
if isinstance(prompts, str) or isinstance(prompts, tuple):
- _prompts = self.vl_prompt_template.prompt_to_messages(prompts)
+ _prompts = cls.prompt_to_messages(prompts)
elif isinstance(prompts[0], tuple) or isinstance(prompts[0], str):
- _prompts = [
- self.vl_prompt_template.prompt_to_messages(x) for x in prompts
- ]
+ _prompts = [cls.prompt_to_messages(x) for x in prompts]
else:
_prompts = prompts
return _prompts
async def _get_prompt_input(self,
- prompt: Dict,
+ messages: Union[str, List[Dict]],
do_preprocess: bool,
sequence_start: bool,
adapter_name: str,
tools: Optional[List[object]] = None,
**kwargs):
- """get input_ids, embeddings and offsets."""
- if do_preprocess:
- decorated = self.vl_prompt_template.messages2prompt(
- prompt, sequence_start)
- else:
- decorated = prompt
- segs = decorated.split(IMAGE_TOKEN)
-
- results = {}
- input_ids = []
- from lmdeploy.vl.templates import (MllamaTempateWrapper,
- MolmoChatTemplateWrapper,
- Qwen2VLChatTemplateWrapper)
- ranges = None
- grid_thws = None
- if len(segs) > 1:
- # yapf: disable
- images_with_kwargs = await self.vl_prompt_template.async_collect_pil_images(prompt) # noqa: E501
- # yapf: enable
- features = []
- if len(images_with_kwargs) > 0:
- images, image_kwargs = list(zip(*images_with_kwargs))
- features = await self.vl_encoder.async_infer(
- images, image_kwargs)
-
- from lmdeploy.vl.templates import MiniCPMVTempateWrapper
- if isinstance(self.vl_prompt_template, MiniCPMVTempateWrapper):
- decorated, features = self.vl_prompt_template.update_image_token( # noqa: E501
- decorated, features)
- segs = decorated.split(IMAGE_TOKEN)
-
- if isinstance(self.vl_prompt_template,
- Qwen2VLChatTemplateWrapper):
- grid_thws = [x['grid_thw'] for x in features]
- features = [x['embeddings'] for x in features]
-
- if isinstance(self.vl_prompt_template, MllamaTempateWrapper):
- # llama3.2 just encode <|image|> and inference
- decorated = decorated.replace(IMAGE_TOKEN, '<|image|>')
- input_ids = self.tokenizer.encode(decorated,
- add_bos=sequence_start)
- results['input_ids'] = input_ids
- results['prompt'] = decorated
- assert len(features)
- results['cross_attention_states'] = features[0]
- return results
-
- if isinstance(self.vl_prompt_template,
- MolmoChatTemplateWrapper):
- return features[0]
-
- features = [x.cpu().numpy() for x in features]
- input_ids = []
- begins = []
- ends = []
- if len(segs) != len(features) + 1:
- logger.error(
- f'the number of {IMAGE_TOKEN} is not equal '
- f'to input images, {len(segs) - 1} vs {len(features)}')
- features = features[:len(segs) - 1]
- for i, seg in enumerate(segs):
- if i > 0 and i <= len(features):
- image_dim = features[i - 1].shape[0]
- begins.append(len(input_ids))
- ends.append(begins[-1] + image_dim)
- input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
- seg_ids = self.tokenizer.encode(seg,
- add_bos=((i == 0)
- and sequence_start))
- input_ids.extend(seg_ids)
- ranges = np.stack([begins, ends], axis=1).tolist()
- results['input_embeddings'] = features or None
- results['input_embedding_ranges'] = ranges or None
+ """process messages and return the required data for the inference
+ engines.
+
+ Refer to pytorch.engine.EngineInstance.async_stream_infer and
+ turbomind.TurboMindInstance.async_stream_infer for the argument
+ specification.
+ """
+ if isinstance(messages, str):
+ return await super()._get_prompt_input(messages, do_preprocess,
+ sequence_start,
+ adapter_name, tools,
+ **kwargs)
+ elif isinstance(messages, List):
+ has_multimodal_input = any(
+ isinstance(message['content'], list) and any(
+ item['type'] in ['image_url', 'image_data']
+ for item in message['content']) for message in messages)
+ if not has_multimodal_input:
+ return await super()._get_prompt_input(messages, do_preprocess,
+ sequence_start,
+ adapter_name, tools,
+ **kwargs)
else:
- input_ids = self.tokenizer.encode(decorated,
- add_bos=sequence_start)
-
- if isinstance(self.vl_prompt_template, Qwen2VLChatTemplateWrapper):
- # TODO: refactor _get_prompt_input function
- mrope_position_ids, mrope_position_delta = \
- self.vl_prompt_template.get_mrope_info(
- len(input_ids), grid_thws=grid_thws,
- embedding_ranges=ranges)
- results['mrope_position_ids'] = mrope_position_ids
- results['mrope_position_delta'] = mrope_position_delta
-
- results['input_ids'] = input_ids
- results['prompt'] = decorated
+ raise RuntimeError(f'unsupported messages {messages}')
+
+ messages = await self.async_convert_to_pil_images(messages)
+ results = await self.vl_encoder.preprocess(messages)
+ if self.backend == 'turbomind':
+            # for the tm engine, this module performs vision embedding after
+            # image preprocessing. It utilizes the hf model's vision embedding
+            # functions and returns the input_ids, input_embeddings,
+            # embedding_ranges and so on. All the returned values are passed
+            # to the tm engine for token generation
+ results = await self.vl_encoder.async_infer(results)
+ results = await self.vl_encoder.wrap_for_turbomind(
+ results, self.chat_template, self.tokenizer, sequence_start)
+ elif self.backend == 'pytorch':
+            # for the pt engine, this module only conducts the image
+            # preprocessing. It leaves the vision embedding to the pt engine
+ results = await self.vl_encoder.wrap_for_pytorch(
+ results, self.chat_template, self.tokenizer, sequence_start)
return results
+ @classmethod
+ async def async_convert_to_pil_images(cls,
+ messages: List[Dict]) -> List[Dict]:
+ """Scan the provided messages to find image URLs or base64-encoded
+        image data, and load the images into Pillow image objects.
+
+ Args:
+ messages (List[Dict]): a user request of GPT4V message format
+ """
+ if isinstance(messages, Dict):
+ messages = [messages]
+ assert isinstance(messages, List)
+
+ out_messages = [None] * len(messages)
+
+ def _inner_call(i, in_messages, out_messages):
+ role = in_messages[i]['role']
+ content = in_messages[i]['content']
+ assert role in ['system', 'user', 'assistant'], \
+ f'unsupported role "{role}"'
+ if role != 'user' or isinstance(content, str):
+                # the content is a plain string prompt or comes from a
+                # non-user role; return it directly
+ out_messages[i] = in_messages[i]
+ return
+ # the role is a user and the content is a list, in which there
+ # might be image_url or image_data
+ assert isinstance(content, List)
+ message = dict(role=role, content=[])
+ for item in content:
+ # image url or base64-encoded image data
+ if item['type'] == 'image_url':
+ """
+ convert the following item:
+ {
+ 'type': 'image_url',
+ 'image_url': {
+ 'url': 'image url or base64-encoded image data',
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ }
+ to:
+ {
+ 'type': 'image',
+ 'image': Pillow.Image,
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ """ # noqa
+ data = item['image_url'].copy()
+ try:
+ url = data.pop('url')
+ image = load_image(url)
+ data.update(type='image', image=image)
+ message['content'].append(data)
+ except KeyError:
+ logger.error(f'invalid format {message}')
+ elif item['type'] == 'image_data':
+ """
+ convert the following item:
+ {
+ 'type': 'image_data',
+ 'image_data': {
+ 'data': Pillow.Image,
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ }
+ to:
+ {
+ 'type': 'image',
+ 'image': Pillow.Image,
+ 'key': 'value' # parameters used in image processing
+ ...
+ }
+ """ # noqa
+ data = item['image_data'].copy()
+ try:
+ image = data.pop('data')
+ data.update(type='image', image=image)
+ message['content'].append(data)
+ except KeyError:
+ logger.error(f'invalid format {message}')
+ elif item['type'] == 'text':
+ message['content'].append(item)
+ else:
+ logger.error(f'unexpected content type {message}')
+ out_messages[i] = message
+
+ await asyncio.gather(*[
+ asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
+ messages, out_messages)
+ for i in range(len(messages))
+ ])
+ return out_messages
+
def batch_infer(self, prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]],
**kwargs):
@@ -173,3 +233,46 @@ def chat(self, prompts: VLPromptType, **kwargs):
last_round = sess.history[-1]
sess.history[-1] = (prompts, last_round[-1])
return sess
+
+ @classmethod
+ def prompt_to_messages(cls, prompt: VLPromptType):
+ """convert prompt to GTP4V format."""
+ messages = {
+ 'role': 'user',
+ 'content': [{
+ 'type': 'text',
+ 'text': '',
+ }]
+ }
+ if isinstance(prompt, str):
+ messages['content'][0]['text'] = prompt
+ else:
+ prompt, images = prompt
+ if not isinstance(images, list):
+ images = [images]
+ messages['content'][0]['text'] = prompt
+ for image in images:
+ # 'image_url': means url or local path to image.
+ # 'image_data': means PIL.Image.Image object.
+ if isinstance(image, str):
+ image = load_image(image)
+ item = {
+ 'type': 'image_data',
+ 'image_data': {
+ 'data': image
+ }
+ }
+ elif isinstance(image, PIL.Image.Image):
+ item = {
+ 'type': 'image_data',
+ 'image_data': {
+ 'data': image
+ }
+ }
+ else:
+ raise ValueError(
+ 'image should be a str(url/path) or PIL.Image.Image')
+
+ messages['content'].append(item)
+
+ return [messages]
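
For reference, the tuple-style prompt ends up in the GPT4V message layout that `_get_prompt_input` consumes; a sketch, assuming `demo.jpg` is a local image that `load_image` can open:

    messages = VLAsyncEngine.prompt_to_messages(('describe the image', 'demo.jpg'))
    # roughly:
    # [{'role': 'user',
    #   'content': [{'type': 'text', 'text': 'describe the image'},
    #               {'type': 'image_data', 'image_data': {'data': <PIL.Image.Image>}}]}]
    #
    # async_convert_to_pil_images later normalises both 'image_url' and
    # 'image_data' items into {'type': 'image', 'image': <PIL.Image.Image>, ...}
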
diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py
index 77f0bc8dc8..176c3191f4 100644
--- a/lmdeploy/turbomind/deploy/converter.py
+++ b/lmdeploy/turbomind/deploy/converter.py
@@ -6,7 +6,7 @@
import fire
import torch
-from lmdeploy.archs import get_model_arch
+from lmdeploy.archs import get_model_arch, search_nested_config
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.model import MODELS, best_match_model
from lmdeploy.utils import get_logger, get_model
@@ -129,16 +129,17 @@ def get_output_model_registered_name_and_config(model_path: str,
] else 'float16'
elif dtype in ['float16', 'bfloat16']:
if weight_type == 'int4':
- logger.warn(f'The model {model_path} is a quantized model, so the '
- f'specified data type {dtype} is ignored')
+ logger.warning(
+ f'The model {model_path} is a quantized model, so the '
+ f'specified data type {dtype} is ignored')
else:
weight_type = dtype
else:
assert 0, f'unsupported specified data type {dtype}'
if weight_type == 'bfloat16' and not is_bf16_supported():
- logger.warn('data type fallback to float16 since '
- 'torch.cuda.is_bf16_supported is False')
+ logger.warning('data type fallback to float16 since '
+ 'torch.cuda.is_bf16_supported is False')
weight_type = 'float16'
config.model_config.model_arch = model_arch
config.model_config.weight_type = weight_type
@@ -174,23 +175,6 @@ def pack_model_repository(workspace_path: str):
dst=osp.join(model_repo_dir, 'postprocessing'))
-def find_quantization_config(nested, target_key):
- if isinstance(nested, dict):
- for key, value in nested.items():
- if key == target_key:
- return value
- if isinstance(value, (dict, list)):
- result = find_quantization_config(value, target_key)
- if result is not None:
- return result
- elif isinstance(nested, list):
- for item in nested:
- result = find_quantization_config(item, target_key)
- if result is not None:
- return result
- return None
-
-
def get_tm_model(model_path,
model_name,
chat_template_name,
@@ -213,8 +197,7 @@ def get_tm_model(model_path,
If it is None, the turbomind model won't be saved
"""
_, cfg = get_model_arch(model_path)
- quant_config = find_quantization_config(cfg.to_dict(),
- 'quantization_config')
+ quant_config = search_nested_config(cfg.to_dict(), 'quantization_config')
if quant_config:
quant_method = quant_config.get('quant_method')
_group_size = int(quant_config.get('group_size', 0))
diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py
index 52497175ef..1754161ff5 100644
--- a/lmdeploy/turbomind/deploy/module.py
+++ b/lmdeploy/turbomind/deploy/module.py
@@ -191,7 +191,7 @@ def __init__(self, model: BaseOutputModel):
self.attn_bias = model.model_config.attn_bias
def _reorder_and_merge(self, qkvo):
- q, k, v, o = map(transpose, qkvo)
+ q, k, v, o = qkvo
# reorder output dim for tm's rotary embedding layout
if self.model.permute_qk:
q = permute_v2(q, self.head_dim)
@@ -202,6 +202,27 @@ def _reorder_and_merge(self, qkvo):
o = torch.zeros_like(q)
return qkv, o
+ def _repeat_kv(self, qkvo, kind: str):
+ """replicate kv."""
+ q, k, v, o = qkvo
+ head_dim = self.model.model_config.size_per_head
+ hidden_dim = self.model.model_config.hidden_units
+
+ def _repeat(x):
+ dim = hidden_dim if kind != 'bias' else 1
+ x = x.reshape(dim, -1, head_dim)
+ x = x.repeat(1, 1, self.model.repeat_kv)
+ x = x.reshape(dim, -1)
+ return x
+
+ k, v = map(_repeat, (k, v))
+ if kind == 'bias':
+ if o is None:
+ o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
+ q, k, v, o = map(torch.squeeze, (q, k, v, o))
+
+ return (q, k, v, o)
+
def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
if all(x is None for x in qkvo):
return
@@ -209,6 +230,9 @@ def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
if is_lora_a:
qkv, o = map(transpose, qkvo)
else:
+ qkvo = tuple(map(transpose, qkvo))
+ if self.model.repeat_kv:
+ qkvo = self._repeat_kv(qkvo, kind)
qkv, o = self._reorder_and_merge(qkvo)
self.model.save_split(pack_fn(qkv),
self._attn.format(idx, 'w_qkv', kind),
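
To make the `_repeat_kv` layout trick above concrete, here is the same reshape/repeat sequence in isolation with toy sizes (only the tensor layout mirrors the module; the dimensions are made up):

```python
import torch

hidden_dim, head_dim, kv_heads, repeat = 8, 2, 1, 4

# weight kind: columns are laid out as kv_heads contiguous head_dim blocks
k = torch.arange(hidden_dim * kv_heads * head_dim, dtype=torch.float32)
k = k.reshape(hidden_dim, kv_heads * head_dim)

# view as (hidden_dim, kv_heads, head_dim), tile each head's block `repeat`
# times along the last axis, then flatten back; every kv head now appears
# `repeat` times so kv_head_num can match the tensor-parallel size
k_rep = k.reshape(hidden_dim, -1, head_dim).repeat(1, 1, repeat)
k_rep = k_rep.reshape(hidden_dim, -1)

assert k_rep.shape == (hidden_dim, kv_heads * repeat * head_dim)
```
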
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index f2c981bb24..7ea1a84f35 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -78,6 +78,17 @@ def __init__(self,
self.model_config.expert_inter_size = _pad_inter_size(
self.model_config.expert_inter_size,
self.model_config.group_size, self.tensor_para_size)
+
+        # head_num is divisible by tp, but kv_head_num is not, and
+        # tp is divisible by kv_head_num
+ assert self.model_config.head_num % self.tensor_para_size == 0
+ self.repeat_kv = 0
+ if (self.tensor_para_size > self.model_config.kv_head_num and
+ self.tensor_para_size % self.model_config.kv_head_num == 0):
+ self.repeat_kv = (self.tensor_para_size //
+ self.model_config.kv_head_num)
+ self.model_config.kv_head_num = self.tensor_para_size
+
self.model_config.verify()
assert self.model_config.kv_head_num % self.tensor_para_size == 0
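
A toy restatement of the replication rule above, outside the converter. The numbers are illustrative: with 32 query heads, 2 kv heads and tp=8, every kv head has to be copied four times so that each rank owns exactly one kv head.

```python
def plan_kv_replication(head_num: int, kv_head_num: int, tp: int):
    """mirror of the checks above, for illustration only."""
    assert head_num % tp == 0
    repeat_kv = 0
    if tp > kv_head_num and tp % kv_head_num == 0:
        repeat_kv = tp // kv_head_num
        kv_head_num = tp
    assert kv_head_num % tp == 0
    return repeat_kv, kv_head_num


# 32 query heads, 2 kv heads, tp=8  ->  replicate each kv head 4x
assert plan_kv_replication(32, 2, 8) == (4, 8)
# 32 query heads, 8 kv heads, tp=4  ->  no replication needed
assert plan_kv_replication(32, 8, 4) == (0, 8)
```
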
diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py
index 11e99edfa0..2b9c5156ed 100644
--- a/lmdeploy/turbomind/supported_models.py
+++ b/lmdeploy/turbomind/supported_models.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from lmdeploy.archs import get_model_arch
+from lmdeploy.archs import get_model_arch, search_nested_config
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
@@ -80,7 +80,12 @@ def _is_head_dim_supported(cfg):
if os.path.exists(triton_model_path):
support_by_turbomind = True
else:
+
arch, cfg = get_model_arch(model_path)
+ quant_method = search_nested_config(cfg.to_dict(), 'quant_method')
+ if quant_method and quant_method in ['smooth_quant']:
+            # turbomind doesn't support models quantized by smooth_quant yet
+ return False
if arch in SUPPORTED_ARCHS.keys():
support_by_turbomind = True
diff --git a/lmdeploy/version.py b/lmdeploy/version.py
index d9f4307a78..f705fcb332 100644
--- a/lmdeploy/version.py
+++ b/lmdeploy/version.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
-__version__ = '0.6.3'
+__version__ = '0.6.4'
short_version = __version__
diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py
index 124fd537c6..7f786d5f90 100644
--- a/lmdeploy/vl/engine.py
+++ b/lmdeploy/vl/engine.py
@@ -1,13 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
-import inspect
-import queue
-import time
-from threading import Thread
from typing import Dict, List, Optional, Union
import torch
-from PIL.Image import Image
from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig,
VisionConfig)
@@ -27,169 +22,94 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None:
raise e
-class Record:
- """Batching manager."""
-
- def __init__(self, thread_safe):
- self.thread_safe = thread_safe
- self.number = []
- self.waiting = []
- self.kwargs = []
- self.done = []
- self.res_que = []
- self.total = 0
-
- def enqueue(self, images: List[Image], kwargs: List[Dict],
- que: Union[queue.Queue, asyncio.Queue]):
- """add ith request to manager."""
- self.number.append(len(images))
- self.waiting.extend(images)
- self.kwargs.extend(kwargs)
- self.res_que.append(que)
- self.total += len(images)
- self.log('received', len(images))
-
- def dequeue(self, max_batch_size):
- """try to dequeue max batch size images."""
- inputs = self.waiting[:max_batch_size]
- kwargs = self.kwargs[:max_batch_size]
- self.waiting = self.waiting[max_batch_size:]
- self.kwargs = self.kwargs[max_batch_size:]
- self.total -= len(inputs)
- self.log('process', len(inputs))
- return inputs, kwargs
-
- def notify(self):
- """set result if request i is finished."""
- if len(self.number) == 0 or self.number[0] > len(self.done):
- return False
- num_images = self.number.pop(0)
- outputs = self.done[:num_images]
- self.done = self.done[num_images:]
- que = self.res_que.pop(0)
- self.log('done', num_images)
- if self.thread_safe:
- que._loop.call_soon_threadsafe(que.put_nowait, outputs)
- else:
- que.put_nowait(outputs)
- return True
-
- def log(self, task: str, num: int):
- logger.info(f'ImageEncoder {task} {num} images, '
- f'left {self.total} images.')
-
-
class ImageEncoder:
"""Image encoder."""
- def __init__(self,
- model_path: str,
- vision_config: VisionConfig = None,
- backend_config: Optional[Union[TurbomindEngineConfig,
- PytorchEngineConfig]] = None):
- self.model = load_vl_model(model_path, backend_config=backend_config)
+ def __init__(
+ self,
+ model_path: str,
+ backend: str,
+ vision_config: VisionConfig = None,
+ backend_config: Optional[Union[TurbomindEngineConfig,
+ PytorchEngineConfig]] = None,
+ ):
+ self.model = load_vl_model(model_path,
+ backend,
+ backend_config=backend_config)
if vision_config is None:
vision_config = VisionConfig()
self.vision_config = vision_config
self.max_batch_size = vision_config.max_batch_size
torch.cuda.empty_cache()
- self._que: asyncio.Queue = None
- self._loop_task: asyncio.Task = None
- if vision_config.thread_safe:
- self._create_thread_safe_task()
-
- def _create_thread_safe_task(self):
- """thread safe loop task."""
- self._loop = asyncio.new_event_loop()
- def _work_thread():
- asyncio.set_event_loop(self._loop)
- self._que = asyncio.Queue()
- self._loop.run_until_complete(self._forward_loop())
-
- thread = Thread(target=_work_thread, daemon=True)
- thread.start()
- self._loop_thread = thread
-
- def _create_event_loop_task(self):
- """event loop task."""
- task = asyncio.get_event_loop().create_task(self._forward_loop())
- self._loop_task = task
- self._loop = task.get_loop()
-
- @property
- def req_que(self):
- if self.vision_config.thread_safe:
- return self._que
- if self._que is None:
- self._que = asyncio.Queue()
- if self._loop_task is None:
- self._create_event_loop_task()
- if asyncio.get_event_loop() != self._loop:
- raise RuntimeError('Current event loop is different from'
- ' the one bound to loop task!')
- return self._que
-
- async def _forward_loop(self):
- """working loop to process images."""
- logger.info('start ImageEncoder._forward_loop')
- record = Record(self.vision_config.thread_safe)
- while True:
- while record.total == 0 or (self._que.qsize() and
- record.total < self.max_batch_size):
- while self._que.qsize() == 0:
- await asyncio.sleep(0.01)
- item = await self._que.get()
- record.enqueue(item[0], item[1], item[2])
- inputs, kwargs = record.dequeue(self.max_batch_size)
- future = asyncio.get_event_loop().run_in_executor(
- None, self.forward, inputs, kwargs)
- future.add_done_callback(_raise_exception_on_finish)
- outputs = await future
- record.done.extend(outputs)
- while record.notify():
- pass
-
- def _init_input_params(self,
- inputs: List[Image],
- params: List[Dict] = None):
- """Check and init inputs params."""
- if params is None:
- params = [{}] * len(inputs)
- assert len(params) == len(inputs), \
- 'different length of inputs and kwargs'
- return params
-
- def forward(self, inputs: List[Image], params: List[Dict] = None):
- """Model forward."""
- params = self._init_input_params(inputs, params)
- time_start = time.perf_counter()
- func_params = inspect.signature(self.model.forward).parameters
- func_inputs = [inputs, params] if len(func_params) > 1 else [inputs]
- outputs = self.model.forward(*func_inputs)
- if isinstance(outputs[0], torch.Tensor):
- outputs = [x.cpu() for x in outputs]
- time_end = time.perf_counter()
- logger.info(f'ImageEncoder forward {len(inputs)} images, '
- f'cost {time_end - time_start:.3f}s')
+ async def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """preprocess multimodal data in the messages."""
+ future = asyncio.get_event_loop().run_in_executor(
+ None, self.model.preprocess, messages)
+ future.add_done_callback(_raise_exception_on_finish)
+ outputs = await future
return outputs
- def infer(self, inputs: List[Image], params: List[Dict] = None):
- """infer."""
- params = self._init_input_params(inputs, params)
- results = self.forward(inputs, params)
- return results
+ async def async_infer(self, messages: List[Dict]) -> List[Dict]:
+ """get multimodal embedding.
+
+ Args:
+            messages (List[Dict]): a list of messages, which is the output
+ of `preprocess()`
+ """
+ future = asyncio.get_event_loop().run_in_executor(
+ None, self.model.forward, messages, self.max_batch_size)
+ future.add_done_callback(_raise_exception_on_finish)
+ outputs = await future
+ return outputs
- async def async_infer(self,
- inputs: List[Image],
- params: List[Dict] = None):
- """async infer."""
- params = self._init_input_params(inputs, params)
- outputs = asyncio.Queue()
- item = (inputs, params, outputs)
- if self.vision_config.thread_safe:
- self._loop.call_soon_threadsafe(self._que.put_nowait, item)
- else:
- self.req_que.put_nowait(item)
- results = await outputs.get()
- return results
+ async def wrap_for_pytorch(self, messages: List[Dict], chat_template,
+ tokenizer, sequence_start) -> List[Dict]:
+ """
+ Args:
+            messages (List[Dict]): a list of messages, which is supposed to be
+ the output of `preprocess`
+ Returns:
+ a dict which will be passed to pytorch engine_instance's forward.
+ The dict is like the following:
+ Dict(
+ 'prompt': 'the prompt after applying chat template'
+ 'input_ids': [],
+ 'multimodal': {
+ 'pixel_values': torch.Tensor,
+ ...
+                    }
+ )
+ """
+ result = self.model.to_pytorch(messages, chat_template, tokenizer,
+ sequence_start)
+ # clear data
+ for i, message in enumerate(messages):
+ if isinstance(message['content'], List):
+ messages[i]['preprocess'] = None
+ return result
+
+ async def wrap_for_turbomind(self, messages: List[Dict], chat_template,
+ tokenizer, sequence_start) -> Dict:
+ """
+ Args:
+            messages (List[Dict]): a list of messages, which is supposed to be
+ the output of `async_infer`
+ Returns:
+            a dict which will be passed to turbomind engine_instance's forward.
+ The dict is like the following:
+ Dict(
+ 'prompt': 'the prompt after applying chat template'
+ 'input_ids': [],
+ 'input_embeddings': list[torch.Tensor],
+ 'input_embedding_ranges': list[torch.Tensor],
+ ...
+                )
+        """
+ result = self.model.to_turbomind(messages, chat_template, tokenizer,
+ sequence_start)
+ # clear data
+ for i, message in enumerate(messages):
+ if isinstance(message['content'], List):
+ messages[i]['preprocess'] = None
+ messages[i]['forward'] = None
+ return result
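
For orientation, a hedged sketch of how the slimmed-down `ImageEncoder` is meant to be driven by the engines. The orchestration normally lives inside lmdeploy's serving code, and the `chat_template`/`tokenizer` objects are assumed to come from there; this is not a public API.

```python
import asyncio


async def encode_for_turbomind(encoder, messages, chat_template, tokenizer):
    # 1. run the HF image processor(s) on every image in the messages
    messages = await encoder.preprocess(messages)
    # 2. run the vision tower to get per-image embeddings (turbomind only)
    messages = await encoder.async_infer(messages)
    # 3. pack prompt, input_ids, input_embeddings and their ranges
    return await encoder.wrap_for_turbomind(messages, chat_template,
                                            tokenizer, sequence_start=True)


async def encode_for_pytorch(encoder, messages, chat_template, tokenizer):
    # the pytorch engine runs the vision tower itself, so only the
    # preprocessing results are packed
    messages = await encoder.preprocess(messages)
    return await encoder.wrap_for_pytorch(messages, chat_template,
                                          tokenizer, sequence_start=True)
```
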
diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py
index 9c5f5f6e6a..0ee22b4688 100644
--- a/lmdeploy/vl/model/base.py
+++ b/lmdeploy/vl/model/base.py
@@ -2,8 +2,7 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
-import PIL
-import torch
+import numpy as np
from mmengine import Registry
from transformers import AutoConfig
@@ -20,35 +19,227 @@ def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
- hf_config: AutoConfig = None):
+ hf_config: AutoConfig = None,
+ backend: str = ''):
"""init."""
self.model_path = model_path
self.with_llm = with_llm
self.max_memory = max_memory
+ self.backend = backend
if hf_config is None:
_, hf_config = get_model_arch(model_path)
self.hf_config = hf_config
- self.build_model()
@abstractmethod
- def build_model():
- """build model."""
+    def build_preprocessor(self):
+        """build the preprocessor.
+
+        NOTE: When the derived class implements this method, try not to
+        introduce the upstream model repo as a third-party dependency.
+ """
raise NotImplementedError()
+    def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind.
+
+ But when `with_llm=True`, load the whole VLM model
+ """
+ if self.backend == 'turbomind' or self.with_llm:
+ raise NotImplementedError()
+
@abstractmethod
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """preprocess multimodal data in the messages. The derived class,
+        i.e., a specific vision model, takes charge of image preprocessing
+        and result management.
+        It can integrate the result into the messages list, or insert it
+        into the individual image item.
+        Args:
+            messages(List[Dict]): multimodal data in a list of dicts, as follows:
+ [
+ {'role': 'user', 'content': 'user prompt'},
+                {'role': 'assistant', 'content': 'AI response'},
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'string',
+ },
+ {
+ 'type': 'image',
+ 'image': pillow.Image,
+ 'key1': value1,
+ ...
+ },
+ {
+ 'type': 'image',
+ 'image': pillow.Image,
+ 'key1': value1,
+ ...
+ },
+ ...
+ ]
+ }
+ {....}
+ ]
+ Returns:
+ the message list with preprocessing results included, which is
+ determined by the derived classes
+ """ # noqa
+ raise NotImplementedError()
+
def forward(self,
- images: List[PIL.Image.Image],
- image_kwargs: List[Dict] = None) -> List[torch.Tensor]:
- """extract image feature.
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
Args:
- images (List[PIL.Image.Image]): input images
- image_kwargs (List[Dict]): input kwargs for each images
-
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
Return:
- List[torch.Tensor]: extract image feature for each input image
+ the message list with forwarding results included, which is
+ determined by the derived classes
"""
- raise NotImplementedError()
+ if self.backend == 'turbomind':
+ raise NotImplementedError()
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ """pack the preprocessing results in a format compatible with what is
+ required by pytorch engine. ONLY implement it when the backend is
+ pytorch engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ chat_template: the chat template defined in `lmdeploy/model.py`
+            tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ if self.backend == 'pytorch':
+ raise NotImplementedError()
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ """pack the forwarding results in a format compatible with what is
+ required by turbomind engine. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ chat_template: the chat template defined in `lmdeploy/model.py`
+            tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ if self.backend == 'turbomind':
+ raise NotImplementedError()
+
+ @staticmethod
+ def collect_images(messages):
+        """gather all images along with their respective parameters from the
+        messages and compile them into a single list. RGB conversion is left
+        to the derived classes.
+
+ Args:
+            messages (List[Dict]): a list of chat messages
+        Returns:
+            List[Tuple[Image, Dict]]: a list of (image, params) tuples
+ """ # noqa
+ images = []
+ for message in messages:
+ content = message['content']
+ if not isinstance(content, List):
+ continue
+ images.extend([
+ (x['image'],
+ {k: v
+ for k, v in x.items() if k not in {'type', 'image'}})
+ for x in content if x['type'] == 'image'
+ ])
+ return images
+
+ @staticmethod
+ def to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start):
+ """auxiliary function to pack the preprocessing results in a format
+ compatible with what is required by pytorch engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ prompt(str): the prompt after applying chat template
+ IMAGE_TOKEN(str): a placeholder where image tokens will be
+ inserted
+            tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ # collect all preprocessing result from messages
+ preps = [x['content'] for x in messages if x['role'] == 'preprocess']
+ assert len(preps) == 1
+ preps = preps[0]
+
+ # split prompt into segments and validate data
+ segs = prompt.split(IMAGE_TOKEN)
+ assert len(segs) == len(preps) + 1, (
+            f'the number of {IMAGE_TOKEN} placeholders does not equal '
+            f'the number of input images: {len(segs) - 1} vs {len(preps)}')
+
+ # calculate the image token offset for each image
+ input_ids = []
+ for i, seg in enumerate(segs):
+ if i > 0 and i <= len(preps):
+ preps[i - 1].update(offset=len(input_ids))
+ image_tokens = preps[i - 1]['image_tokens']
+ image_token_id = preps[i - 1]['image_token_id']
+ input_ids.extend([image_token_id] * image_tokens)
+ token_ids = tokenizer.encode(seg,
+ add_bos=((i == 0) and sequence_start))
+ input_ids.extend(token_ids)
+
+ return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)
+
+ @staticmethod
+ def to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start):
+ """auxiliary function to pack the forwarding results in a format
+ compatible with what is required by turbomind engine.
+
+ Args:
+ messages(List[Dict]): the output of `preprocess`
+ prompt(str): the prompt after applying chat template
+ IMAGE_TOKEN(str): a placeholder where image tokens will be
+ inserted
+            tokenizer: the tokenizer model
+ sequence_start: starting flag of a sequence
+ """
+ # collect image features from messages
+ features = [x['content'] for x in messages if x['role'] == 'forward']
+ features = features[0]
+ features = [x.cpu().numpy() for x in features]
+
+ # split prompt into segments and validate data
+ segs = prompt.split(IMAGE_TOKEN)
+ assert len(segs) == len(features) + 1, (
+            f'the number of {IMAGE_TOKEN} placeholders does not equal '
+            f'the number of input images: {len(segs) - 1} vs {len(features)}')
+
+        # tokenize the prompt, and get input_embeddings and input_embedding_ranges
+ input_ids = []
+ begins = []
+ ends = []
+ IMAGE_DUMMY_TOKEN_INDEX = 0
+ for i, seg in enumerate(segs):
+ if i > 0 and i <= len(features):
+ image_dim = features[i - 1].shape[0]
+ begins.append(len(input_ids))
+ ends.append(begins[-1] + image_dim)
+ input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
+ seg_ids = tokenizer.encode(seg,
+ add_bos=((i == 0) and sequence_start))
+ input_ids.extend(seg_ids)
+ ranges = np.stack([begins, ends], axis=1).tolist()
+ return dict(prompt=prompt,
+ input_ids=input_ids,
+ input_embeddings=features,
+ input_embedding_ranges=ranges)
@classmethod
def match(cls, config: AutoConfig):
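
A toy walk-through of `to_pytorch_aux`, showing how `IMAGE_TOKEN` placeholders in the templated prompt turn into runs of `image_token_id` with recorded offsets. The whitespace tokenizer and the token id 92544 are invented purely for illustration.

```python
class ToyTokenizer:

    vocab = {}

    def encode(self, text, add_bos=False):
        ids = [self.vocab.setdefault(t, len(self.vocab) + 1)
               for t in text.split()]
        return ([0] + ids) if add_bos else ids


IMAGE_TOKEN = '<IMAGE_TOKEN>'
messages = [
    {'role': 'user', 'content': [{'type': 'text', 'text': 'describe it'}]},
    {'role': 'preprocess',
     'content': [{'image_tokens': 3, 'image_token_id': 92544}]},
]
prompt = f'USER: {IMAGE_TOKEN} describe the image ASSISTANT:'

segs = prompt.split(IMAGE_TOKEN)
preps = [m['content'] for m in messages if m['role'] == 'preprocess'][0]
input_ids, tok = [], ToyTokenizer()
for i, seg in enumerate(segs):
    if 0 < i <= len(preps):
        preps[i - 1].update(offset=len(input_ids))
        input_ids.extend([preps[i - 1]['image_token_id']] *
                         preps[i - 1]['image_tokens'])
    input_ids.extend(tok.encode(seg, add_bos=(i == 0)))

# three 92544 ids follow the 'USER:' token; their offset (2) is recorded so
# the pytorch engine knows where to scatter the vision features later
print(input_ids, preps[0]['offset'])
```
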
diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py
index 2401b42259..00e668c034 100644
--- a/lmdeploy/vl/model/builder.py
+++ b/lmdeploy/vl/model/builder.py
@@ -2,6 +2,8 @@
import os
from typing import Optional, Union
+import torch
+
from lmdeploy.archs import get_model_arch
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.utils import get_logger, get_model
@@ -29,6 +31,7 @@
def load_vl_model(model_path: str,
+ backend: str,
with_llm: bool = False,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None):
@@ -36,8 +39,9 @@ def load_vl_model(model_path: str,
Args:
model_path(str): the path or repo_id from model hub of the model
- with_llm(bool): whether to remove the LLM part from the model.
- When it is False, it means removing LLM part
+ backend(str): the name of inference backend
+        with_llm(bool): whether to load the LLM part of the VLM. Set it to
+            False for VLM inference and to True for VLM quantization
backend_config: the config of the inference engine
"""
if not os.path.exists(model_path):
@@ -49,7 +53,6 @@ def load_vl_model(model_path: str,
max_memory = None
if not with_llm:
- import torch
tp = getattr(backend_config, 'tp', 1)
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}
@@ -57,30 +60,21 @@ def load_vl_model(model_path: str,
kwargs = dict(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory,
- hf_config=hf_config)
+ hf_config=hf_config,
+ backend=backend)
for name, module in VISION_MODELS.module_dict.items():
try:
if module.match(hf_config):
logger.info(f'matching vision model: {name}')
- return module(**kwargs)
- except Exception:
- logger.error(f'matching vision model: {name} failed')
+ model = module(**kwargs)
+ model.build_preprocessor()
+ # build the vision part of a VLM model when backend is
+ # turbomind, or load the whole VLM model when `with_llm==True`
+ if backend == 'turbomind' or with_llm:
+ model.build_model()
+ return model
+ except Exception as e:
+ logger.error(f'build vision model {name} failed, {e}')
raise
raise ValueError(f'unsupported vl model with config {hf_config}')
-
-
-def vl_model_with_tokenizer(model_path: str, with_llm: bool = True):
- """load visual model."""
- vl_model = load_vl_model(model_path, with_llm).vl_model
- llm = vl_model
- if hasattr(vl_model, 'language_model'): # deepseek vl
- llm = vl_model.language_model
- if hasattr(vl_model, 'llm'): # MiniCPMV
- llm = vl_model.llm
- llm.config.use_cache = False
- llm.half().eval()
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained(model_path,
- trust_remote_code=True)
- return vl_model, llm, tokenizer
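
A hedged usage sketch of the updated `load_vl_model` entry point; the repo id is only an example, and a CUDA-capable environment with the model weights available is assumed.

```python
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.vl.model.builder import load_vl_model

model_path = 'OpenGVLab/InternVL2-8B'  # example repo id, swap in your own
vl_model = load_vl_model(model_path,
                         backend='turbomind',
                         backend_config=TurbomindEngineConfig(tp=1))
# build_preprocessor() has already run; build_model() ran too because the
# backend is turbomind, so forward() is usable for this model
```
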
diff --git a/lmdeploy/vl/model/cogvlm.py b/lmdeploy/vl/model/cogvlm.py
index ea5a06159e..07d97153f9 100644
--- a/lmdeploy/vl/model/cogvlm.py
+++ b/lmdeploy/vl/model/cogvlm.py
@@ -1,13 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import List
-
-import torch
-from PIL.Image import Image
-from transformers import AutoModelForCausalLM
+from typing import Dict, List
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
+
+logger = get_logger('lmdeploy')
@VISION_MODELS.register_module()
@@ -16,7 +13,7 @@ class CogVLMVisionModel(VisonModel):
_arch = 'CogVLMForCausalLM'
- def build_model(self):
+ def build_preprocessor(self):
from torchvision import transforms
self.image_transform = transforms.Compose([
transforms.Resize(
@@ -26,57 +23,73 @@ def build_model(self):
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
+ image_size = self.hf_config.vision_config['image_size']
+ patch_size = self.hf_config.vision_config['patch_size']
+ self.n_token_per_image = 2 + (image_size // patch_size // 2)**2
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import get_balanced_memory, infer_auto_device_map
- with init_empty_weights(), warnings.catch_warnings():
- model = AutoModelForCausalLM.from_config(self.hf_config,
- trust_remote_code=True)
- if not self.with_llm:
- del model.lm_head
- for key in ['layers', 'norm', 'embed_tokens']:
- setattr(model.model, key, None)
- else:
- self.vl_model = model
+ def build_model(self):
+ if self.with_llm:
+ from transformers import AutoModelForCausalLM
+ self.vl_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path, device_map='cpu', trust_remote_code=True)
+ else:
+ raise NotImplementedError('turbomind has not supported cogvlm yet')
- no_split_module_classes = ['TransformerLayer']
- max_memory = get_balanced_memory(
- model,
- max_memory=self.max_memory,
- dtype=torch.half,
- no_split_module_classes=no_split_module_classes)
- device_map = infer_auto_device_map(
- model,
- no_split_module_classes=no_split_module_classes,
- max_memory=max_memory,
- dtype=torch.half)
- same_device_keys = [('model.vision.linear_proj', 'model.vision.boi',
- 'model.vision.eoi')]
- for keys in same_device_keys:
- keys = [k for k in keys if k in device_map]
- if len(keys) <= 1:
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to the spec of `super().preprocess`"""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, _ in images:
+ image = image.convert('RGB')
+ pixel_values = self.image_transform(image)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
continue
- for k in keys[1:]:
- device_map[k] = device_map[keys[0]]
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+
+ prompt_messages.append(
+ dict(role='user', content=content[0], num_images=n_images))
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map=device_map if not self.with_llm else {'': 'cpu'},
- no_split_module_classes=no_split_module_classes,
- dtype=torch.half)
- self.model = model.model.vision
- self.model.eval()
+ from lmdeploy.model import Vicuna
+ llm_chat_template = Vicuna(eoa=chat_template.eoa,
+ stop_words=chat_template.stop_words)
+ prompt = ''
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for i, msg in enumerate(prompt_messages):
+ num_images = msg.pop('num_images', 0)
+ if num_images == 0:
+ role = msg['role']
+ msg = llm_chat_template.messages2prompt([msg], sequence_start
+ and i == 0)
+ msg = dict(role=role, content=msg)
+ prompt_i = chat_template.messages2prompt([msg], sequence_start
+ and i == 0)
+ if num_images > 0:
+ prompt_i = (IMAGE_TOKEN * num_images) + prompt_i
+ prompt += prompt_i
+ return prompt, IMAGE_TOKEN
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.image_transform(x) for x in outputs]
- outputs = torch.stack(outputs, dim=0).to(device='cuda:0',
- dtype=torch.half)
- outputs = self.model(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
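
A quick check of the `n_token_per_image` arithmetic introduced above. Treat the 490/14 values as illustrative rather than cogvlm's actual config; the `// 2` reflects a 2x downsampling of the patch grid and the `+ 2` presumably covers the boi/eoi tokens seen in the old loading code.

```python
image_size, patch_size = 490, 14           # illustrative values
n_token_per_image = 2 + (image_size // patch_size // 2)**2
print(n_token_per_image)                   # 2 + (35 // 2)**2 = 2 + 289 = 291
```
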
diff --git a/lmdeploy/vl/model/deepseek.py b/lmdeploy/vl/model/deepseek.py
index bfbf03f01e..9780744cf2 100644
--- a/lmdeploy/vl/model/deepseek.py
+++ b/lmdeploy/vl/model/deepseek.py
@@ -1,15 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
import warnings
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoModelForCausalLM
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
def check_deepseek_vl_install():
"""check deepseek_vl install."""
@@ -18,8 +19,8 @@ def check_deepseek_vl_install():
except ImportError:
raise ImportError(
'To use DeepSeekVLModel, please install deepseek_vl by '
- 'pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git'
- ' --no-deps')
+ '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git'
+ ' --no-deps`')
@VISION_MODELS.register_module()
@@ -28,18 +29,22 @@ class DeepSeekVisionModel(VisonModel):
_arch = 'MultiModalityCausalLM'
- def build_model(self):
+ def build_preprocessor(self):
check_deepseek_vl_install()
- # empty init
- from accelerate import init_empty_weights
from deepseek_vl.models import VLChatProcessor
+ self.image_processor = VLChatProcessor.from_pretrained(
+ self.model_path).image_processor
+
+ def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
+ from accelerate import init_empty_weights
with init_empty_weights():
warnings.simplefilter('ignore')
model = AutoModelForCausalLM.from_pretrained(self.model_path)
+ self.vl_model = model
if not self.with_llm:
del model.language_model
- else:
- self.vl_model = model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(model,
@@ -79,23 +84,111 @@ def build_model(self):
device_map=device_map if not self.with_llm else {'': 'cpu'},
dtype=torch.half)
+ self.model = model.eval()
self.vision_model = model.vision_model.eval()
self.aligner = model.aligner.eval()
- self.image_processor = VLChatProcessor.from_pretrained(
- self.model_path).image_processor
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`"""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, _ in images:
+ image = image.convert('RGB')
+ pixel_values = self.image_processor(
+ [image], return_tensors='pt').pixel_values
+ outputs.append(
+ dict(
+ pixel_values=pixel_values,
+ image_size=image.size,
+ # refer to https://github.com/deepseek-ai/DeepSeek-VL/blob/main/deepseek_vl/models/processing_vlm.py # noqa
+                    # where the number of image tokens is hardcoded to 576
+ image_tokens=576,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- pixel_values = self.image_processor(outputs,
- return_tensors='pt').pixel_values
- pixel_values = pixel_values.to(device=next(
- self.vision_model.parameters()).device,
- dtype=torch.float16)
- # [b x n_images, T2, D]
- images_embeds = self.aligner(self.vision_model(pixel_values))
-
- outputs = torch.split(images_embeds, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=next(
+ self.vision_model.parameters()).device,
+ dtype=torch.float16)
+ # [b x n_images, T2, D]
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.aligner(self.vision_model(pixel_values))
+ feats = torch.split(feats, 1, dim=0)
+ outputs.extend([x.squeeze() for x in feats])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ # apply chat template to get the prompt
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ content = content[0]
+ n_image = sum(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ n_placeholder = content.count(IMAGE_TOKEN)
+ if n_placeholder == 0:
+ logger.warning(
+                    f"""for deepseek-vl models, the user should insert the {IMAGE_TOKEN}
+                    into the user prompt manually. Please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html
+                    for more details.""")  # noqa
+ if n_placeholder != 0 and n_placeholder != n_image:
+ logger.error(
+ f'unmatched placeholder and image: {n_placeholder} vs '
+                    f'{n_image}. Ignoring the placeholders')
+ content = content.replace(IMAGE_TOKEN, '')
+ n_placeholder = 0
+ if n_placeholder == 0:
+ if n_image == 1:
+ content = f'{IMAGE_TOKEN}{content}'
+ else:
+ content = ''.join([
+ f'{IMAGE_TOKEN} is Figure {str(i)}.\n'
+ for i in range(n_image)
+ ]) + content
+ prompt_messages.append(dict(role='user', content=content))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
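
The chunked vision forward above can be illustrated with toy shapes; `fake_vision` stands in for the aligner plus vision tower, and the tensor sizes are invented.

```python
import torch


def fake_vision(pixel_values: torch.Tensor) -> torch.Tensor:
    b = pixel_values.shape[0]
    return torch.zeros(b, 576, 2048)  # 576 tokens per image, as noted above


inputs = [dict(pixel_values=torch.rand(1, 3, 384, 384)) for _ in range(5)]
max_batch_size, outputs = 2, []
for idx in range(0, len(inputs), max_batch_size):
    pixel_values = torch.cat(
        [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]], dim=0)
    feats = fake_vision(pixel_values)          # (chunk, 576, 2048)
    outputs.extend([x.squeeze() for x in torch.split(feats, 1, dim=0)])

assert len(outputs) == 5 and outputs[0].shape == (576, 2048)
```
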
diff --git a/lmdeploy/vl/model/glm_4v.py b/lmdeploy/vl/model/glm_4v.py
index 34e060f4c9..813813bf09 100644
--- a/lmdeploy/vl/model/glm_4v.py
+++ b/lmdeploy/vl/model/glm_4v.py
@@ -1,77 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
-import warnings
-from typing import List
-
-import torch
-from PIL.Image import Image
from transformers import AutoConfig
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
+
+logger = get_logger('lmdeploy')
@VISION_MODELS.register_module()
class GLM4VisionModel(VisonModel):
"""glm-4v-9b vision model."""
- _arch = 'ChatGLMModel'
+ _arch = ['ChatGLMModel', 'ChatGLMForConditionalGeneration']
@classmethod
def match(cls, config: AutoConfig):
"""check whether the config match the model."""
arch = config.architectures[0]
- if arch == cls._arch and hasattr(config, 'vision_config'):
+ if arch in cls._arch and hasattr(config, 'vision_config'):
return True
return False
- def build_model(self):
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import infer_auto_device_map
+ def build_preprocessor(self):
from torchvision import transforms
-
- with init_empty_weights(), warnings.catch_warnings():
- warnings.simplefilter('ignore')
- from transformers import AutoModelForCausalLM
- model = AutoModelForCausalLM.from_config(self.hf_config,
- trust_remote_code=True)
- if not self.with_llm:
- del model.transformer.embedding
- del model.transformer.rotary_pos_emb
- del model.transformer.encoder
- del model.transformer.output_layer
- else:
- self.vl_model = model
-
- no_split_module_classes = ['TransformerLayer']
-
- device_map = infer_auto_device_map(
- model,
- no_split_module_classes=no_split_module_classes,
- max_memory=self.max_memory,
- dtype=torch.half)
-
- same_device_keys = [
- ('transformer.vision.linear_proj', 'transformer.vision.boi',
- 'transformer.vision.eoi')
- ]
- for keys in same_device_keys:
- keys = [k for k in keys if k in device_map]
- if len(keys) <= 1:
- continue
- for k in keys[1:]:
- device_map[k] = device_map[keys[0]]
-
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map=device_map if not self.with_llm else {'': 'cpu'},
- no_split_module_classes=no_split_module_classes,
- dtype=torch.half)
-
- model.eval()
- self.model = model
self.image_transform = transforms.Compose([
transforms.Resize(
(self.hf_config.vision_config['image_size'], ) * 2,
@@ -80,15 +33,65 @@ def build_model(self):
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
+ image_size = self.hf_config.vision_config['image_size']
+ patch_size = self.hf_config.vision_config['patch_size']
+ self.n_token_per_image = 2 + (image_size // patch_size // 2)**2
+
+ def build_model(self):
+ if self.with_llm:
+ from transformers import AutoModelForCausalLM
+ self.vl_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path, device_map='cpu', trust_remote_code=True)
+ else:
+ raise NotImplementedError('turbomind has not supported glm4v yet')
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`"""
+ outputs = []
+ for message in messages:
+ if not isinstance(message['content'], List):
+ continue
+ images = [
+ x['image'] for x in message['content'] if x['type'] == 'image'
+ ]
+ if len(images) > 1:
+ logger.warning(
+ f'glm4v does not support the input of multiple images'
+ f' in a single chat round, but got {len(images)} images.')
+ # we still pass all the images to the model and let the
+ # model decide what to do
+ images = [x.convert('RGB') for x in images]
+ pixel_values = [self.image_transform(x) for x in images]
+ outputs.extend([
+ dict(pixel_values=_2,
+ image_size=_1.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0) for _1, _2 in zip(images, pixel_values)
+ ])
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ content = message['content']
+ if isinstance(content, str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['preprocess', 'forward']:
+ continue
+ prompt = [x['text'] for x in content if x['type'] == 'text']
+ n_images = len([1 for x in content if x['type'] == 'image'])
+ prompt = ''.join([f'{IMAGE_TOKEN}\n'] * n_images) + prompt[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.image_transform(x) for x in outputs]
- outputs = torch.stack(outputs, dim=0).to(device='cuda:0',
- dtype=torch.half)
- outputs = self.model.transformer.vision(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
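
A toy trace of `proc_messages` above: every multimodal user turn is flattened into one `<IMAGE_TOKEN>` line per image followed by its text, while the bookkeeping roles ('preprocess', 'forward') are dropped; the chat-template step itself is omitted here.

```python
IMAGE_TOKEN = '<IMAGE_TOKEN>'
messages = [
    {'role': 'user',
     'content': [{'type': 'image', 'image': object()},
                 {'type': 'text', 'text': 'what is in the picture?'}]},
    {'role': 'preprocess', 'content': []},
]

prompt_messages = []
for message in messages:
    content = message['content']
    if isinstance(content, str):
        prompt_messages.append(message)
        continue
    if message['role'] in ['preprocess', 'forward']:
        continue
    text = [x['text'] for x in content if x['type'] == 'text'][0]
    n_images = sum(1 for x in content if x['type'] == 'image')
    prompt_messages.append(
        dict(role='user', content=f'{IMAGE_TOKEN}\n' * n_images + text))

print(prompt_messages[0]['content'])
# '<IMAGE_TOKEN>\nwhat is in the picture?'
```
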
diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index fa67192f11..979b8d1a39 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -1,10 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
from typing import Dict, List
import torch
-from PIL.Image import Image
-from transformers import AutoModel, CLIPImageProcessor
+from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
@@ -80,34 +78,16 @@ class InternVLVisionModel(VisonModel):
_arch = 'InternVLChatModel'
- def build_model(self):
- """Load model."""
- from accelerate import init_empty_weights
- with init_empty_weights():
- config = self.hf_config
- # transformers below 4.37.0 may raise error about flash_attn
- config.llm_config.attn_implementation = 'eager'
- model = AutoModel.from_config(config, trust_remote_code=True)
- if not self.with_llm:
- del model.language_model
- else:
- self.vl_model = model
- model.half()
+ def __init__(self,
+ model_path: str,
+ with_llm: bool = False,
+ max_memory: Dict[int, int] = None,
+ hf_config: AutoConfig = None,
+ backend: str = ''):
+ super().__init__(model_path, with_llm, max_memory, hf_config, backend)
- from accelerate import load_checkpoint_and_dispatch
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map='auto' if not self.with_llm else {'': 'cpu'},
- max_memory=self.max_memory,
- no_split_module_classes=['InternVisionEncoderLayer'],
- dtype=torch.half)
-
- # We need eval mode to freeze the weights in model, thus,
- # avoid randomness in inference.
- self.model = model.eval()
- self.config = config
+ def build_preprocessor(self):
+ self.config = self.hf_config
dynamic_image_size = getattr(self.config, 'dynamic_image_size', False)
image_processor = None
try:
@@ -131,62 +111,180 @@ def build_model(self):
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
+ self.processor = self._preprocess_v1_5
self._forward_func = self._forward_v1_5
else:
+ self.processor = self._preprocess
self.image_processor = image_processor
self._forward_func = self._forward
- def _preprocess_v1_5(self, images: List[Image], params: List[Dict] = None):
- if params is not None:
- assert len(images) == len(
- params), 'different length of images and params'
- else:
- params = [{}] * len(images)
+ force_image_size = self.hf_config.force_image_size
+ patch_size = self.hf_config.vision_config.patch_size
+ downsample_ratio = self.hf_config.downsample_ratio
+ self.image_tokens_per_patch = int(
+ (force_image_size // patch_size)**2 * (downsample_ratio**2))
- image_res = {'low': 6, 'medium': 12, 'high': 24}
+ def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
+ from accelerate import init_empty_weights
+ with init_empty_weights():
+ # transformers below 4.37.0 may raise error about flash_attn
+ self.config.llm_config.attn_implementation = 'eager'
+ model = AutoModel.from_config(self.config, trust_remote_code=True)
+ self.vl_model = model
+ if not self.with_llm:
+ del model.language_model
- outputs = []
- for image, param in zip(images, params):
- max_num = param.get('max_dynamic_patch')
- if max_num is None or not isinstance(max_num, int):
- res_key = param.get('detail', 'default')
- max_num = image_res.get(res_key, self.config.max_dynamic_patch)
- out = dynamic_preprocess(
- image,
- min_num=self.config.min_dynamic_patch,
- max_num=max_num,
- image_size=self.config.vision_config.image_size,
- use_thumbnail=self.config.use_thumbnail)
- out = [self.transform(x) for x in out]
- out = torch.stack(out) # (patch) x c x h x w
- outputs.append(out)
- return outputs
+ model.half()
+ from accelerate import load_checkpoint_and_dispatch
+ with disable_logging():
+ load_checkpoint_and_dispatch(
+ model=model,
+ checkpoint=self.model_path,
+ device_map='auto' if not self.with_llm else {'': 'cpu'},
+ max_memory=self.max_memory,
+ no_split_module_classes=['InternVisionEncoderLayer'],
+ dtype=torch.half)
+
+ # We need eval mode to freeze the weights in model, thus,
+ # avoid randomness in inference.
+ self.model = model.eval()
+
+ def _preprocess_v1_5(self, image, params=None):
+ image_res = {'low': 6, 'medium': 12, 'high': 24}
+ max_num = params.get('max_dynamic_patch')
+ if max_num is None or not isinstance(max_num, int):
+ res_key = params.get('detail', 'default')
+ max_num = image_res.get(res_key, self.config.max_dynamic_patch)
+ out = dynamic_preprocess(
+ image,
+ min_num=self.config.min_dynamic_patch,
+ max_num=max_num,
+ image_size=self.config.vision_config.image_size,
+ use_thumbnail=self.config.use_thumbnail)
+ pixel_values = [self.transform(x) for x in out]
+ # (patch) x c x h x w
+ pixel_values = torch.stack(pixel_values)
+ return pixel_values
- def _forward_v1_5(self, images: List[Image], params: List[Dict] = None):
+ def _forward_v1_5(self, inputs, max_batch_size):
"""forward for internvl-chat-v1-5."""
- outputs = self._preprocess_v1_5(images, params)
- split = [x.shape[0] for x in outputs]
- outputs = torch.cat(outputs, dim=0)
- outputs = outputs.to(self.model.device, dtype=torch.float16)
- outputs = self.model.extract_feature(outputs)
- outputs = torch.split(outputs, split, dim=0)
- outputs = [x.reshape(-1, x.shape[-1]) for x in outputs]
+ assert all(x.get('pixel_values') is not None for x in inputs)
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ split = [x.shape[0] for x in pixel_values]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(self.model.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.model.extract_feature(pixel_values)
+ feats = torch.split(feats, split, dim=0)
+ outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
return outputs
- def _forward(self, images: List[Image], params: List[Dict] = None):
+ def _preprocess(self, image, params=None):
"""forward for internvl-chat-v1-1, internvl-chat-v1-2."""
- pixel_values = self.image_processor(images=images,
+ pixel_values = self.image_processor(images=image,
return_tensors='pt').pixel_values
- pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)
- outputs = self.model.extract_feature(pixel_values)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
+ return pixel_values
+
+ def _forward(self, inputs, max_batch_size):
+ """forward for internvl-chat-v1-1, internvl-chat-v1-2."""
+ assert all(x.get('pixel_values') is not None for x in inputs)
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(self.model.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.model.extract_feature(pixel_values)
+ feats = torch.split(feats, 1, dim=0)
+ outputs.extend([x.squeeze() for x in feats])
return outputs
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to `super().preprocess()` for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = self.processor(image, params)
+ image_tokens = (pixel_values.shape[0] *
+ self.image_tokens_per_patch)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_tokens=image_tokens,
+ image_token_id=0,
+ image_size=image.size))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
@torch.no_grad()
def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- images = [x.convert('RGB') for x in images]
- return self._forward_func(images, params)
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = self._forward_func(inputs, max_batch_size)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = content[0]
+            if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
+                prompt = prompt.replace(f'{IMAGE_TOKEN}',
+                                        f'<img>{IMAGE_TOKEN}</img>')
+                prompt = prompt.replace('</img><img>', '')
+                prompt = prompt.replace('<img><img>', '<img>')
+                prompt = prompt.replace('</img></img>', '</img>')
+            elif IMAGE_TOKEN not in prompt:
+                prompt = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + prompt
+ else:
+ pass
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
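
How `image_tokens_per_patch` above works out, using typical InternVL2-style values (448px tiles, 14px patches, 0.5 pixel-shuffle downsample); take the concrete numbers as an assumption, the formula is what matters.

```python
force_image_size, patch_size, downsample_ratio = 448, 14, 0.5
image_tokens_per_patch = int(
    (force_image_size // patch_size)**2 * (downsample_ratio**2))
print(image_tokens_per_patch)  # (448 // 14)**2 * 0.25 = 1024 * 0.25 = 256

# with dynamic_image_size, an image split into e.g. 7 tiles (6 + thumbnail)
# therefore contributes 7 * 256 = 1792 image tokens to the prompt
```
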
diff --git a/lmdeploy/vl/model/internvl_llava.py b/lmdeploy/vl/model/internvl_llava.py
index f607082b18..17a12f71ca 100644
--- a/lmdeploy/vl/model/internvl_llava.py
+++ b/lmdeploy/vl/model/internvl_llava.py
@@ -2,14 +2,13 @@
import warnings
from contextlib import contextmanager
-from typing import List, Union
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM
from lmdeploy.utils import get_logger
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel
from lmdeploy.vl.model.utils import rewrite_ctx
from .utils import disable_logging, disable_transformers_logging
@@ -18,14 +17,13 @@
def check_llava_install():
- """check llava install."""
try:
from llava.model.multimodal_encoder.clip_encoder import \
InternVisionModel # noqa: F401
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
- 'pip install "git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava" --no-deps' # noqa: E501
+ '`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`' # noqa: E501
)
@@ -65,7 +63,7 @@ def init_empty_vit():
@VISION_MODELS.register_module()
-class InternVLLlavaVisionModel(VisonModel):
+class InternVLLlavaVisionModel(LlavaVisionModel):
"""Llava visual model."""
@classmethod
@@ -78,9 +76,12 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+ return super().build_preprocessor()
+
def build_model(self):
- """build model & load weights."""
- # check llava install
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_llava_install()
# currently, only support llava llama
from llava.model.language_model.llava_llama import ( # noqa
@@ -98,13 +99,12 @@ def build_model(self):
} # disable vision part quantization
model = AutoModelForCausalLM.from_config(self.config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
- else:
- self.vl_model = model
with init_empty_vit():
vision_tower = model.get_vision_tower()
@@ -137,42 +137,43 @@ def build_model(self):
self.vision_tower = model.model.vision_tower.eval()
self.mm_projector = model.model.mm_projector.eval()
- def encode_images(self, images: torch.Tensor) -> torch.Tensor:
- """encode images."""
- image_features = self.vision_tower(images)
- image_features = self.mm_projector(image_features)
- return image_features
-
- def preprocess(
- self,
- images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
- """preprocess."""
- # TODO: gpu processor
- from llava.mm_utils import process_images
- images = [x.convert('RGB') for x in images]
- image_processor = self.vision_tower.image_processor
- outputs = process_images(images, image_processor, self.config)
- return outputs
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to `super().preprocess()` for spec."""
+ return super().preprocess(messages)
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- images = self.preprocess(images)
- if isinstance(images, list):
- images = [
- x.to(self.vision_tower.device, dtype=torch.float16)
- for x in images
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
]
- else:
- images = images.to(self.vision_tower.device, dtype=torch.float16)
-
- if type(images) is list or images.ndim == 5:
- concat_images = torch.cat([image for image in images], dim=0)
- image_features = self.encode_images(concat_images)
- split_sizes = [image.shape[0] for image in images]
- image_features = torch.split(image_features, split_sizes, dim=0)
- image_features = [x.flatten(0, 1) for x in image_features]
- else:
- image_features = self.encode_images(images)
- image_features = [x for x in image_features]
- return image_features
+ split_sizes = [x.shape[0] for x in pixel_values]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=self.vision_tower.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ if pixel_values.ndim == 5:
+ feats = self.encode_images(pixel_values)
+ feats = torch.split(feats, split_sizes, dim=0)
+ feats = [x.flatten(0, 1) for x in feats]
+ else:
+ feats = self.encode_images(pixel_values)
+ feats = [x for x in feats]
+ outputs.extend(feats)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
diff --git a/lmdeploy/vl/model/llava.py b/lmdeploy/vl/model/llava.py
index 0b18f460cd..7ad919bef7 100644
--- a/lmdeploy/vl/model/llava.py
+++ b/lmdeploy/vl/model/llava.py
@@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
-# Modified from
-# https://github.com/haotian-liu/LLaVA.git
+# Modified from https://github.com/haotian-liu/LLaVA.git
+import ast
+import math
import warnings
from contextlib import contextmanager
-from typing import List, Union
+from typing import Dict, List
import torch
-from PIL.Image import Image
+from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM
from lmdeploy.utils import get_logger
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
from lmdeploy.vl.model.utils import disable_logging, rewrite_ctx
logger = get_logger('lmdeploy')
@@ -23,16 +24,14 @@ def check_llava_install():
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
- 'pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps' # noqa: E501
+ '`pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps`' # noqa: E501
)
def _clip_vision_tower_load_model(self, **kwargs):
logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}')
- from transformers import (CLIPImageProcessor, CLIPVisionConfig,
- CLIPVisionModel)
- self.image_processor = CLIPImageProcessor.from_pretrained(
- self.vision_tower_name)
+ from transformers import CLIPVisionConfig, CLIPVisionModel
+
config = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel._from_config(config=config)
self.vision_tower.requires_grad_(False)
@@ -53,8 +52,166 @@ def init_llava_vision_tower(config):
yield
+def select_best_resolution(original_size, possible_resolutions):
+ """Selects the best resolution from a list of possible resolutions based on
+ the original size.
+
+ Args:
+ original_size (tuple): The original size of the image in the format (width, height).
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (width, height).
+ """ # noqa
+ original_width, original_height = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float('inf')
+
+ for width, height in possible_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(
+ original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height,
+ original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (
+ effective_resolution == max_effective_resolution
+ and wasted_resolution < min_wasted_resolution):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (width, height)
+
+ return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+ """Resize and pad an image to a target resolution while maintaining aspect
+ ratio.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ target_resolution (tuple): The target resolution (width, height) of the image.
+
+ Returns:
+ PIL.Image.Image: The resized and padded image.
+ """ # noqa
+ original_width, original_height = image.size
+ target_width, target_height = target_resolution
+
+ scale_w = target_width / original_width
+ scale_h = target_height / original_height
+
+ if scale_w < scale_h:
+ new_width = target_width
+ new_height = min(math.ceil(original_height * scale_w), target_height)
+ else:
+ new_height = target_height
+ new_width = min(math.ceil(original_width * scale_h), target_width)
+
+ # Resize the image
+ resized_image = image.resize((new_width, new_height))
+
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+ paste_x = (target_width - new_width) // 2
+ paste_y = (target_height - new_height) // 2
+ new_image.paste(resized_image, (paste_x, paste_y))
+
+ return new_image
+
+
+def divide_to_patches(image, patch_size):
+ """Divides an image into patches of a specified size.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ patch_size (int): The size of each patch.
+
+ Returns:
+ list: A list of PIL.Image.Image objects representing the patches.
+ """
+ patches = []
+ width, height = image.size
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ box = (j, i, j + patch_size, i + patch_size)
+ patch = image.crop(box)
+ patches.append(patch)
+
+ return patches
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+ """Process an image with variable resolutions.
+
+ Args:
+ image (PIL.Image.Image): The input image to be processed.
+ processor: The image processor object.
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+ Returns:
+ torch.Tensor: A tensor containing the processed image patches.
+ """ # noqa
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
+ image_padded = resize_and_pad_image(image, best_resolution)
+
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+ image_original_resize = image.resize(
+ (processor.size['shortest_edge'], processor.size['shortest_edge']))
+
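+    # prepend a resized full view of the original image so the model sees global
+    # context in addition to the high-resolution local patches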
+ image_patches = [image_original_resize] + patches
+ image_patches = [
+ processor.preprocess(image_patch,
+ return_tensors='pt')['pixel_values'][0]
+ for image_patch in image_patches
+ ]
+ return torch.stack(image_patches, dim=0)
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ elif image_aspect_ratio == 'anyres':
+ for image in images:
+ image = process_anyres_image(image, image_processor,
+ model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
@VISION_MODELS.register_module()
-class LlavaVisionModel(VisonModel):
+class LlavaVisionModel(LlavaHfVisionModel):
"""Llava visual model."""
@classmethod
@@ -73,9 +230,20 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+ from transformers import CLIPImageProcessor
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ self.hf_config.mm_vision_tower)
+ config = AutoConfig.from_pretrained(self.hf_config.mm_vision_tower)
+ image_size = config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
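+        # one visual token per ViT patch; the 'cls_patch' feature-selection mode
+        # keeps the CLS token as well, hence the extra token below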
+ self.n_token_per_image = (image_size // patch_size)**2
+ if self.hf_config.mm_vision_select_feature == 'cls_patch':
+ self.n_token_per_image += 1
+
def build_model(self):
- """build model & load weights."""
- # check llava install
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_llava_install()
self.arch = self.hf_config.architectures[0]
@@ -104,15 +272,13 @@ def build_model(self):
model = AutoModelForCausalLM.from_config(self.config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
# remove the LLM part from llava model.
- # Instead, Load the LLM part to turbomind engine
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
- else:
- self.vl_model = model
# init empty vision_tower, the embedding layer in CLIPVisionModel
# can't init right under init_empty_weights
@@ -143,101 +309,113 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
image_features = self.mm_projector(image_features)
return image_features
- def preprocess(
- self,
- images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
- """preprocess."""
- # TODO: gpu processor
- from llava.mm_utils import process_images
- images = [x.convert('RGB') for x in images]
- image_processor = self.vision_tower.image_processor
- outputs = process_images(images, image_processor, self.config)
- return outputs
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to `super().preprocess()` for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = process_images([image], self.image_processor,
+ self.config)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
from llava.model.llava_arch import (get_anyres_image_grid_shape,
unpad_image)
- image_sizes = [x.size for x in images]
- images = self.preprocess(images)
- if isinstance(images, list):
- images = [
- x.to(device=self.vision_tower.device, dtype=torch.float16)
- for x in images
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ image_sizes = [
+ x['image_size'] for x in inputs[idx:idx + max_batch_size]
]
- else:
- images = images.to(device=self.vision_tower.device,
- dtype=torch.float16)
- if type(images) is list or images.ndim == 5:
- if type(images) is list:
- images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
- concat_images = torch.cat([image for image in images], dim=0)
- image_features = self.encode_images(concat_images)
- split_sizes = [image.shape[0] for image in images]
- image_features = torch.split(image_features, split_sizes, dim=0)
- mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type',
- 'flat')
- image_aspect_ratio = getattr(self.config, 'image_aspect_ratio',
- 'square')
- if mm_patch_merge_type == 'flat':
- image_features = [x.flatten(0, 1) for x in image_features]
- elif mm_patch_merge_type.startswith('spatial'):
- new_image_features = []
- for image_idx, image_feature in enumerate(image_features):
- if image_feature.shape[0] > 1:
- base_image_feature = image_feature[0]
- image_feature = image_feature[1:]
- height = width = self.vision_tower.num_patches_per_side
- assert height * width == base_image_feature.shape[0]
- if image_aspect_ratio == 'anyres':
- num_patch_width, num_patch_height = \
- get_anyres_image_grid_shape(
- image_sizes[image_idx],
- self.config.image_grid_pinpoints,
- self.vision_tower.config.image_size)
- image_feature = image_feature.view(
- num_patch_height, num_patch_width, height,
- width, -1)
- else:
- raise NotImplementedError
- if 'unpad' in mm_patch_merge_type:
- image_feature = image_feature.permute(
- 4, 0, 2, 1, 3).contiguous()
- image_feature = image_feature.flatten(1,
- 2).flatten(
- 2, 3)
- image_feature = unpad_image(
- image_feature, image_sizes[image_idx])
- image_feature = torch.cat((
- image_feature,
- self.model.image_newline[:, None, None].expand(
- *image_feature.shape[:-1], 1).to(
- image_feature.device)),
- dim=-1)
- image_feature = image_feature.flatten(1,
- 2).transpose(
- 0, 1)
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ if pixel_values[0].ndim == 5:
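+                # a 5-D tensor means each image was tiled into several patches
+                # (e.g. 'anyres'); keep the per-image patch counts so the encoded
+                # features can be regrouped below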
+ split_sizes = [x.shape[1] for x in pixel_values]
+ pixel_values = torch.cat([x for x in pixel_values], dim=1)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ pixel_values = pixel_values.squeeze(0)
+ pixel_values = pixel_values.to(device=self.vision_tower.device,
+ dtype=torch.float16)
+ feats = self.encode_images(pixel_values)
+ feats = torch.split(feats, split_sizes, dim=0)
+ mm_patch_merge_type = getattr(self.config,
+ 'mm_patch_merge_type', 'flat')
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio',
+ 'square')
+ if mm_patch_merge_type == 'flat':
+                    outputs.extend([x.flatten(0, 1) for x in feats])
+ elif mm_patch_merge_type.startswith('spatial'):
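+                    # 'spatial' keeps the base (global) feature and folds the patch
+                    # features back into their 2-D grid; the 'unpad' variant also
+                    # crops the padded border and appends image_newline separators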
+ for img_idx, feat in enumerate(feats):
+ if feat.shape[0] > 1:
+ base_feat = feat[0]
+ feat = feat[1:]
+ height = self.vision_tower.num_patches_per_side
+ width = self.vision_tower.num_patches_per_side
+ assert height * width == base_feat.shape[0]
+ if image_aspect_ratio == 'anyres':
+ num_patch_width, num_patch_height = \
+ get_anyres_image_grid_shape(
+ image_sizes[img_idx],
+ self.config.image_grid_pinpoints,
+ self.vision_tower.config.image_size)
+ feat = feat.view(num_patch_height,
+ num_patch_width, height,
+ width, -1)
+ else:
+ raise NotImplementedError
+ if 'unpad' in mm_patch_merge_type:
+ feat = feat.permute(4, 0, 2, 1, 3).contiguous()
+ feat = feat.flatten(1, 2).flatten(2, 3)
+ feat = unpad_image(feat, image_sizes[img_idx])
+ feat = torch.cat(
+ (feat, self.model.
+ image_newline[:, None, None].expand(
+ *feat.shape[:-1], 1).to(feat.device)),
+ dim=-1)
+ feat = feat.flatten(1, 2).transpose(0, 1)
+ else:
+ feat = feat.permute(0, 2, 1, 3, 4).contiguous()
+ feat = feat.flatten(0, 3)
+ feat = torch.cat((base_feat, feat), dim=0)
else:
- image_feature = image_feature.permute(
- 0, 2, 1, 3, 4).contiguous()
- image_feature = image_feature.flatten(0, 3)
- image_feature = torch.cat(
- (base_image_feature, image_feature), dim=0)
- else:
- image_feature = image_feature[0]
- if 'unpad' in mm_patch_merge_type:
- image_feature = torch.cat(
- (image_feature,
- self.model.image_newline[None].to(
- image_feature.device)),
- dim=0)
- new_image_features.append(image_feature)
- image_features = new_image_features
+ feat = feat[0]
+ if 'unpad' in mm_patch_merge_type:
+ feat = torch.cat(
+ (feat, self.model.image_newline[None].to(
+ feat.device)),
+ dim=0)
+ outputs.append(feat)
+ else:
+ raise ValueError('Unexpected mm_patch_merge_type: '
+ f'{self.config.mm_patch_merge_type}')
else:
- raise ValueError('Unexpected mm_patch_merge_type: '
- f'{self.config.mm_patch_merge_type}')
- else:
- image_features = self.encode_images(images)
- image_features = [x for x in image_features]
- return image_features
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=self.vision_tower.device,
+ dtype=torch.float16)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.encode_images(pixel_values)
+ outputs.extend([x for x in feats])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py
index 31be101ae8..c4e3c90bfb 100644
--- a/lmdeploy/vl/model/llava_hf.py
+++ b/lmdeploy/vl/model/llava_hf.py
@@ -1,15 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
import warnings
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoProcessor
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
@VISION_MODELS.register_module()
class LlavaHfVisionModel(VisonModel):
@@ -17,19 +18,31 @@ class LlavaHfVisionModel(VisonModel):
_arch = 'LlavaForConditionalGeneration'
+ def build_preprocessor(self):
+ processor = AutoProcessor.from_pretrained(self.model_path,
+ trust_remote_code=True)
+ if hasattr(processor, 'tokenizer'):
+ del processor.tokenizer
+ processor.prtokenizer = None
+ self.processor = processor.image_processor
+ image_size = self.hf_config.vision_config.image_size
+ patch_size = self.hf_config.vision_config.patch_size
+ self.n_token_per_image = (image_size // patch_size)**2
+ if self.hf_config.vision_feature_select_strategy == 'full':
+ self.n_token_per_image += 1
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter('ignore')
from transformers import LlavaForConditionalGeneration
model = LlavaForConditionalGeneration._from_config(self.hf_config)
+ self.vl_model = model
if not self.with_llm:
del model.language_model
- for key in ['language_model']:
- setattr(model, key, None)
- else:
- self.vl_model = model
# fix for llava-hf/llava-interleave-qwen-7b-hf
setattr(model.config, 'tie_word_embeddings', False)
@@ -45,35 +58,97 @@ def build_model(self):
dtype=torch.half)
model.eval()
self.model = model
- # processor
- processor = AutoProcessor.from_pretrained(self.model_path,
- trust_remote_code=True)
- if hasattr(processor, 'tokenizer'):
- del processor.tokenizer
- processor.prtokenizer = None
- self.processor = processor.image_processor
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to `super().preprocess()` for spec."""
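+        # each stage communicates through the `messages` list: this method appends
+        # a dict with role 'preprocess', which `forward` consumes later on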
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = self.processor(
+ image, return_tensors='pt',
+ input_data_format='channels_last').pixel_values
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- pixel_values = self.processor(
- images, return_tensors='pt',
- input_data_format='channels_last')['pixel_values']
- pixel_values = pixel_values.to(device=self.model.device,
- dtype=self.model.dtype)
- image_outputs = self.model.vision_tower.forward(
- pixel_values, output_hidden_states=True)
- image_features = image_outputs.hidden_states[
- self.hf_config.vision_feature_layer]
- if self.hf_config.vision_feature_select_strategy == 'default':
- image_features = image_features[:, 1:]
- elif self.hf_config.vision_feature_select_strategy == 'full':
- image_features = image_features
- else:
- raise ValueError(
- 'Unexpected select feature strategy: '
- f'{self.hf_config.vision_feature_select_strategy}')
- image_features = self.model.multi_modal_projector(image_features)
- outputs = torch.split(image_features, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ pixel_values = pixel_values.to(device=self.model.device,
+ dtype=self.model.dtype)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ image_outputs = self.model.vision_tower.forward(
+ pixel_values, output_hidden_states=True)
+ image_features = image_outputs.hidden_states[
+ self.hf_config.vision_feature_layer]
+ if self.hf_config.vision_feature_select_strategy == 'default':
+ image_features = image_features[:, 1:]
+ elif self.hf_config.vision_feature_select_strategy == 'full':
+ image_features = image_features
+ else:
+ raise ValueError(
+ 'Unexpected select feature strategy: '
+ f'{self.hf_config.vision_feature_select_strategy}')
+ image_features = self.model.multi_modal_projector(image_features)
+ image_features = torch.split(image_features, 1, dim=0)
+ outputs.extend([x.squeeze() for x in image_features])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = (IMAGE_TOKEN + '\n') * n_images + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py
index 9223ebea4f..d355a48d60 100644
--- a/lmdeploy/vl/model/llava_next.py
+++ b/lmdeploy/vl/model/llava_next.py
@@ -1,46 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
+import itertools
import warnings
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
-from transformers import AutoProcessor
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.utils import get_logger
+from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
@VISION_MODELS.register_module()
-class LlavaNextVisionModel(VisonModel):
+class LlavaNextVisionModel(LlavaHfVisionModel):
"""Llava hf vision model."""
_arch = 'LlavaNextForConditionalGeneration'
- def build_model(self):
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
+ def build_preprocessor(self):
+ super().build_preprocessor()
+ # build the model with empty weights. The model will be used in
+ # `preprocess` to get the image token number
+ from accelerate import init_empty_weights
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter('ignore')
from transformers import LlavaNextForConditionalGeneration
- model = LlavaNextForConditionalGeneration._from_config(
+ self.model = LlavaNextForConditionalGeneration._from_config(
self.hf_config)
+ self.vl_model = self.model
if not self.with_llm:
- del model.language_model
- for key in ['language_model']:
- setattr(model, key, None)
- else:
- self.vl_model = model
+ del self.model.language_model
+
+ def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
+ from accelerate import load_checkpoint_and_dispatch
+ from accelerate.utils import get_balanced_memory, infer_auto_device_map
no_split_module_classes = ['CLIPEncoderLayer']
max_memory = get_balanced_memory(
- model,
+ self.model,
max_memory=self.max_memory,
dtype=torch.half,
no_split_module_classes=no_split_module_classes)
device_map = infer_auto_device_map(
- model,
+ self.model,
no_split_module_classes=no_split_module_classes,
max_memory=max_memory,
dtype=torch.half)
@@ -55,75 +60,128 @@ def build_model(self):
with disable_logging():
load_checkpoint_and_dispatch(
- model=model,
+ model=self.model,
checkpoint=self.model_path,
device_map=device_map if not self.with_llm else {'': 'cpu'},
no_split_module_classes=no_split_module_classes,
dtype=torch.half)
- model.eval()
- self.model = model
- # processor
- processor = AutoProcessor.from_pretrained(self.model_path,
- trust_remote_code=True)
- if hasattr(processor, 'tokenizer'):
- del processor.tokenizer
- processor.prtokenizer = None
- self.processor = processor.image_processor
+ self.model.eval()
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to the spec of `super().preprocess()`"""
from transformers.models.llava_next.modeling_llava_next import \
image_size_to_num_patches
- """forward."""
- processed_inputs = self.processor(images,
- return_tensors='pt',
- input_data_format='channels_last')
- pixel_values = processed_inputs['pixel_values'].to(
- device=self.model.device, dtype=self.model.dtype)
- image_sizes = processed_inputs['image_sizes'].to(
- device=self.model.device, dtype=self.model.dtype)
- # ! infer image_num_patches from image_sizes
- image_num_patches = [
- image_size_to_num_patches(
- image_size=imsize,
- grid_pinpoints=self.hf_config.image_grid_pinpoints,
- patch_size=self.hf_config.vision_config.image_size,
- ) for imsize in image_sizes
- ]
- # figure out if pixel_values is concatenated or stacked
- if pixel_values.dim() == 5:
- # stacking when input is
- # (batch_size, num_patches, num_channels, height, width)
- _pixel_values_list = [
- pix_val[:num_patch]
- for pix_val, num_patch in zip(pixel_values, image_num_patches)
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ result = self.processor(image,
+ return_tensors='pt',
+ input_data_format='channels_last')
+ # ! infer image_num_patches from image_sizes
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.hf_config.image_grid_pinpoints,
+ patch_size=self.hf_config.vision_config.image_size,
+ ) for imsize in result['image_sizes']
]
- pixel_values = torch.cat(_pixel_values_list, dim=0)
- elif pixel_values.dim() != 4:
- # otherwise has to be stacked from list of
- # (num_patches, num_channels, height, width)
- raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
- 'expect to be of 4 or 5 dimensions')
- image_outputs = self.model.vision_tower.forward(
- pixel_values, output_hidden_states=True)
- image_features = image_outputs.hidden_states[
- self.hf_config.vision_feature_layer]
- if self.hf_config.vision_feature_select_strategy == 'default':
- image_features = image_features[:, 1:]
- elif self.hf_config.vision_feature_select_strategy == 'full':
- image_features = image_features
- else:
- raise ValueError(
- 'Unexpected select feature strategy: '
- f'{self.hf_config.vision_feature_select_strategy}')
- image_features = self.model.multi_modal_projector(image_features)
- image_features = torch.split(image_features, image_num_patches, dim=0)
- image_features, feature_lens = self.model.pack_image_features(
- image_features,
- image_sizes,
- image_newline=self.model.image_newline,
- )
- outputs = torch.split(image_features,
- feature_lens.cpu().numpy().tolist(),
- dim=0)
- return outputs
+
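+            # run `pack_image_features` on dummy zero features solely to count how
+            # many image tokens this image will occupy; the feature values and the
+            # random image_newline are irrelevant here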
+ hidden_size = self.hf_config.text_config.hidden_size
+ fake_image_features = torch.zeros(
+ [image_num_patches[0], self.n_token_per_image, hidden_size])
+ image_sizes = result['image_sizes']
+ image_newline = torch.randn(self.hf_config.text_config.hidden_size)
+ strategy = self.hf_config.vision_feature_select_strategy
+ _, image_tokens = self.model.pack_image_features(
+ [fake_image_features],
+ image_sizes,
+ vision_feature_select_strategy=strategy,
+ image_newline=image_newline)
+ result.update(
+ dict(image_size=image.size,
+ image_patches=image_num_patches,
+ image_tokens=image_tokens,
+ image_token_id=0))
+ outputs.append(result)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ @torch.no_grad()
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'].to(device=self.model.device,
+ dtype=self.model.dtype)
+ for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ image_sizes = [
+ x['image_sizes'].to(device=self.model.device,
+ dtype=self.model.dtype)
+ for x in inputs[idx:idx + max_batch_size]
+ ]
+ image_sizes = torch.cat(image_sizes, dim=0)
+ image_num_patches = [
+                x['image_patches'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ image_num_patches = list(itertools.chain(*image_num_patches))
+ # figure out if pixel_values is concatenated or stacked
+ if pixel_values.dim() == 5:
+ # stacking when input is
+ # (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [
+ pix_val[:num_patch] for pix_val, num_patch in zip(
+ pixel_values, image_num_patches)
+ ]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of
+ # (num_patches, num_channels, height, width)
+ raise ValueError(
+ f'pixel_values of shape {pixel_values.shape}, '
+ 'expect to be of 4 or 5 dimensions')
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ image_outputs = self.model.vision_tower.forward(
+ pixel_values, output_hidden_states=True)
+ image_features = image_outputs.hidden_states[
+ self.hf_config.vision_feature_layer]
+ strategy = self.hf_config.vision_feature_select_strategy
+ if strategy == 'default':
+ image_features = image_features[:, 1:]
+ elif strategy == 'full':
+ image_features = image_features
+ else:
+ raise ValueError('Unexpected select feature strategy: '
+ f'{strategy}')
+ image_features = self.model.multi_modal_projector(image_features)
+ image_features = torch.split(image_features,
+ image_num_patches,
+ dim=0)
+ image_features, feature_lens = self.model.pack_image_features(
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy=strategy,
+ image_newline=self.model.image_newline,
+ )
+ image_features = torch.split(image_features,
+ feature_lens.cpu().numpy().tolist(),
+ dim=0)
+ outputs.extend(image_features)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
diff --git a/lmdeploy/vl/model/mini_gemeni.py b/lmdeploy/vl/model/mini_gemeni.py
index 0565daeba5..eca70aca51 100644
--- a/lmdeploy/vl/model/mini_gemeni.py
+++ b/lmdeploy/vl/model/mini_gemeni.py
@@ -3,16 +3,18 @@
import os.path as osp
import warnings
from contextlib import contextmanager
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import (add_device_hook, disable_logging,
disable_transformers_logging,
hack_import_with)
+logger = get_logger('lmdeploy')
+
def check_mini_gemini_install():
"""check mini gemini install."""
@@ -22,8 +24,8 @@ def check_mini_gemini_install():
except ImportError:
raise ImportError(
'To use MiniGeminiVisionModel, please install minigemini by '
- 'pip install git+https://github.com/dvlab-research/MGM.git'
- ' --no-deps')
+ '`pip install git+https://github.com/dvlab-research/MGM.git'
+ ' --no-deps`')
def _build_vision_tower(vision_tower_cfg, **kwargs):
@@ -169,7 +171,15 @@ class MiniGeminiVisionModel(VisonModel):
_arch = ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM']
+ def build_preprocessor(self):
+ # pytorch engine will not support mini-gemini. Therefore, in order to
+ # reuse the previous code as much as possible, we do not extract image
+ # preprocessor from `build_model` function.
+ pass
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_mini_gemini_install()
# empty init
from accelerate import init_empty_weights
@@ -193,13 +203,12 @@ def build_model(self):
vision_tower.load_model()
vision_tower_aux = model.get_vision_tower_aux()
vision_tower_aux.load_model()
+ self.vl_model = model
if not self.with_llm:
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
- else:
- self.vl_model = model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
@@ -246,11 +255,35 @@ def build_model(self):
self.image_processor = image_processor
self.process_images = process_images
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ return messages
+
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- image_tensor = self.process_images(outputs, self.image_processor,
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
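+        # `preprocess` is a no-op for mini-gemini, so collect the raw PIL images
+        # directly from the user messages here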
+ images = []
+ for message in messages:
+ if not isinstance(message['content'], List):
+ continue
+ _ = [
+ x['image'] for x in message['content'] if x['type'] == 'image'
+ ]
+            assert len(_) == 1, 'MiniGeminiLlama accepts ONE input ' \
+                f'image, but got {len(_)} images'
+ images.extend(_)
+
+ image_tensor = self.process_images(images, self.image_processor,
self.model.config)
image_grid = getattr(self.model.config, 'image_grid', 1)
if hasattr(self.model.config, 'image_size_aux'):
@@ -301,15 +334,47 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
image.to(self.model.device, dtype=torch.float16)
for image in image_tensor_aux
]
+ logger.info(f'vision forward bs: {len(image_tensor)}')
else:
image_tensor = image_tensor.to(self.model.device,
dtype=torch.float16)
image_tensor_aux = image_tensor_aux.to(self.model.device,
dtype=torch.float16)
-
+ logger.info(f'vision forward shape: {image_tensor.shape}')
images_embeds = self.model.encode_images(image_tensor,
image_tensor_aux)
outputs = torch.split(images_embeds, 1, dim=0)
outputs = [x.squeeze() for x in outputs]
- return outputs
+        messages.append(dict(role='forward', content=outputs))
+        return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = (IMAGE_TOKEN + '\n') * n_images + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+        assert 0, 'mini-gemini is not supported by pytorch engine'
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py
index 4e30190c1d..6b0c5f1508 100644
--- a/lmdeploy/vl/model/minicpmv.py
+++ b/lmdeploy/vl/model/minicpmv.py
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
import warnings
from typing import Dict, List
import torch
from PIL.Image import Image
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM
from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
@@ -19,8 +20,33 @@ class MiniCPMVModel(VisonModel):
_arch = 'MiniCPMV'
+ def __init__(self,
+ model_path: str,
+ with_llm: bool = False,
+ max_memory: Dict[int, int] = None,
+ hf_config: AutoConfig = None,
+ backend: str = ''):
+ super().__init__(model_path, with_llm, max_memory, hf_config, backend)
+ if not hasattr(self.hf_config, 'version'):
+            raise ValueError('Cannot find `version` in config.json. '
+                             'Please check out the latest model')
+ version = str(self.hf_config.version)
+ if version not in ['2.5', '2.6']:
+ raise ValueError(
+ f'Only support v2.5 and v2.6, but got version {version}')
+ self.version = version
+
+ def build_preprocessor(self):
+ from transformers import AutoProcessor
+ self.processor = AutoProcessor.from_pretrained(self.model_path,
+ trust_remote_code=True)
+ self.image_processor = self.processor.image_processor
+ self._preprocess_func = (self._preprocess_v2_5 if self.version == '2.5'
+ else self._preprocess_v2_6)
+
def build_model(self):
- """build model & load weights."""
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter('ignore')
@@ -29,10 +55,9 @@ def build_model(self):
config.quantization_config = {} # disable vision part quantization
model = AutoModelForCausalLM.from_config(config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
del model.llm
- else:
- self.vl_model = model
from accelerate import load_checkpoint_and_dispatch
with disable_logging():
@@ -50,46 +75,11 @@ def build_model(self):
device=model.resampler.proj.device)
self.config = config
self.model = model.eval()
- self.init_forward_func()
-
- def init_forward_func(self):
- if not hasattr(self.config, 'version'):
- msg = 'LMDeploy only support `MiniCPM-V-2_6` and '\
- '`MiniCPM-Llama3-V-2_5`.\nCan not find `version` in config, ' \
- 'please consider update the huggingface model.'
- logger.warn(msg)
-
- self._forward_func = self._forward_v2_5
- if hasattr(self.config, 'version'):
- version = str(self.config.version)
- if version == '2.6':
- self._forward_func = self._forward_v2_6
-
- if self._forward_func == self._forward_v2_5:
- logger.info('using _forward_v2_5')
- if not hasattr(self.model, 'slice_image'):
- # adapt new code commit 287e3f85 (MiniCPM-Llama3-V-2_5)
- from transformers import AutoProcessor
- processor = AutoProcessor.from_pretrained(
- self.model_path, trust_remote_code=True)
- self.model.slice_image = processor.image_processor.slice_image
-
- def _reshape_by_patch(x):
- out = x.cpu().numpy()
- out = processor.image_processor.reshape_by_patch(out)
- return torch.from_numpy(out).to(device=x.device)
-
- self.model.reshape_by_patch = _reshape_by_patch
-
- if self._forward_func == self._forward_v2_6:
- logger.info('using _forward_v2_6')
- from transformers import AutoProcessor
- self.model.processor = AutoProcessor.from_pretrained(
- self.model_path, trust_remote_code=True)
def _get_slice_image(self, image: Image):
slice_images = []
- source_image, patches, best_grid = self.model.slice_image(image)
+ source_image, patches, best_grid = self.image_processor.slice_image(
+ image)
slice_images.append(source_image)
if len(patches) > 0:
for i in range(len(patches)):
@@ -103,114 +93,198 @@ def _reshape_by_patch(self, slice_images):
for slice_image in slice_images:
slice_image = self.model.transform(slice_image)
H, W = slice_image.shape[1:]
- patches.append(self.model.reshape_by_patch(slice_image))
+ slice_image = slice_image.numpy()
+ slice_image = self.image_processor.reshape_by_patch(slice_image)
+ slice_image = torch.from_numpy(slice_image)
+ patches.append(slice_image)
H //= self.config.patch_size
W //= self.config.patch_size
tgt_sizes.append(torch.Tensor([H, W]).type(torch.int32))
return patches, tgt_sizes
- def _forward_v2_5(self, images: List[Image], params: List[Dict] = None):
- """forward for MiniCPM-Llama3-V-2_5."""
- patches = []
- tgt_sizes = []
- best_grids = []
- num_patches = []
- for image in images:
- slice_images, best_grid = self._get_slice_image(image)
- _patches, _tgt_sizes = self._reshape_by_patch(slice_images)
- num_patches.append(len(_patches))
- patches.extend(_patches)
- tgt_sizes.extend(_tgt_sizes)
- best_grids.append(best_grid)
-
- patches = [
- x.to(dtype=torch.half, device=self.model.device) for x in patches
- ]
- patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches]
- tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
- max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
- all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches,
- batch_first=True,
- padding_value=0.0)
- B, L, _ = all_pixel_values.shape
- all_pixel_values = all_pixel_values.permute(0, 2,
- 1).reshape(B, 3, -1, L)
- patch_attn_mask = torch.zeros((B, 1, max_patches),
- dtype=torch.bool,
- device=self.model.device)
- for i in range(B):
- patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
- vision_embedding = self.model.vpm(
- all_pixel_values.type(torch.half),
- patch_attention_mask=patch_attn_mask).last_hidden_state
- vision_embedding = self.model.resampler(vision_embedding, tgt_sizes)
- vision_embedding = torch.split(vision_embedding, num_patches, 0)
+ def _preprocess_v2_5(self, image: Image, params: Dict = None) -> Dict:
+ """image preprocessing for MiniCPM-Llama3-V-2_5."""
+ slice_images, best_grid = self._get_slice_image(image)
+ # pixel_values, tgt_sizes are list of torch tensors
+ pixel_values, tgt_sizes = self._reshape_by_patch(slice_images)
+ num_patches = len(pixel_values)
+ return dict(
+ pixel_values=pixel_values, # a list
+ tgt_sizes=tgt_sizes, # a list
+ best_grid=best_grid,
+ num_patches=num_patches,
+ image_tokens=1,
+ image_token_id=0)
+
+ def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict:
+ """image preprocessing for MiniCPM-V-2_6."""
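+        # request-level params (e.g. 'max_slice_nums', 'use_image_id') may override
+        # the defaults carried by the image processor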
+ max_slice_nums = self.image_processor.max_slice_nums
+ use_image_id = self.image_processor.use_image_id
+ max_slice_nums = params.get('max_slice_nums', max_slice_nums)
+ use_image_id = params.get('use_image_id', use_image_id)
+ outputs = self.image_processor(image, max_slice_nums=max_slice_nums)
+ pixel_values = outputs['pixel_values'][0]
+ num_patches = len(pixel_values)
+ pixel_values = [torch.as_tensor(x) for x in pixel_values]
+ tgt_sizes = outputs['tgt_sizes'][0]
+ tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes]
+ grid = self.image_processor.get_sliced_grid(
+ image_size=image.size, max_slice_nums=max_slice_nums)
+ return dict(
+ pixel_values=pixel_values, # a list
+ tgt_sizes=tgt_sizes, # a list
+ best_grid=grid,
+ num_patches=num_patches,
+ image_tokens=1,
+ image_token_id=0,
+ use_image_id=use_image_id)
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """refer to `super().preprocess()` for spec."""
outputs = []
- for embeddings, grid in zip(vision_embedding, best_grids):
- embeddings = embeddings.cpu() # n x d x h
- outputs.append(dict(embeddings=embeddings, grid=grid))
+ for i, message in enumerate(messages):
+ if message['role'] != 'user' or not isinstance(
+ message['content'], List):
+ continue
+ for item in message['content']:
+ if item['type'] == 'image':
+ image = item['image'].convert('RGB')
+ params = {
+ k: v
+ for k, v in item.items() if k not in {'type', 'image'}
+ }
+ result = self._preprocess_func(image, params)
+ outputs.append(result)
+ messages[i].update(dict(preprocess=outputs))
+ return messages
- return outputs
+ @torch.no_grad()
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
- def _forward_v2_6(self, images: List[Image], params: List[Dict] = None):
- """forward for MiniCPM-V-2_6."""
- patches = []
- tgt_sizes = []
- best_grids = []
- num_patches = []
- max_slice_nums = self.model.processor.image_processor.max_slice_nums
- use_image_id = self.model.processor.image_processor.use_image_id
- for image, param in zip(images, params):
- max_slice_nums = param.get('max_slice_nums', max_slice_nums)
- use_image_id = param.get('use_image_id', use_image_id)
- outputs = self.model.processor.image_processor(
- image, max_slice_nums=max_slice_nums)
- patches.extend(outputs['pixel_values'][0])
- num_patches.append(len(outputs['pixel_values'][0]))
- tgt_sizes.extend(outputs['tgt_sizes'][0])
- grid = self.model.processor.image_processor.get_sliced_grid(
- image_size=image.size, max_slice_nums=max_slice_nums)
- best_grids.append(grid)
-
- patches = [
- torch.as_tensor(x).to(dtype=torch.half, device=self.model.device)
- for x in patches
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ # collect preprocess results into a list
+ inputs = []
+ inputs = [
+ x['preprocess'] for x in messages if 'preprocess' in x.keys()
]
- patches = [x.flatten(end_dim=1).permute(1, 0) for x in patches]
- tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes]
- tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
- max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
- all_pixel_values = torch.nn.utils.rnn.pad_sequence(patches,
+ # flatten the list
+ inputs = list(itertools.chain(*inputs))
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ tgt_sizes = [
+ x['tgt_sizes'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ num_patches = [
+ x['num_patches'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ # flatten the list
+ tgt_sizes = list(itertools.chain(*tgt_sizes))
+ pixel_values = list(itertools.chain(*pixel_values))
+ pixel_values = [
+ x.to(dtype=torch.half, device=self.model.device)
+ for x in pixel_values
+ ]
+ pixel_values = [
+ x.flatten(end_dim=1).permute(1, 0) for x in pixel_values
+ ]
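+            # pad each image's flattened patch sequence to the longest in this
+            # chunk, then restore the (B, 3, -1, L) layout the vision encoder
+            # expects; patch_attn_mask below marks the real (non-padded) positions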
+ pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values,
batch_first=True,
padding_value=0.0)
- B, L, _ = all_pixel_values.shape
- all_pixel_values = all_pixel_values.permute(0, 2,
- 1).reshape(B, 3, -1, L)
- patch_attn_mask = torch.zeros((B, 1, max_patches),
- dtype=torch.bool,
- device=self.model.device)
- for i in range(B):
- patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
- vision_embedding = self.model.vpm(
- all_pixel_values.type(torch.half),
- patch_attention_mask=patch_attn_mask,
- tgt_sizes=tgt_sizes).last_hidden_state
- vision_embedding = self.model.resampler(vision_embedding, tgt_sizes)
- vision_embedding = torch.split(vision_embedding, num_patches, 0)
- outputs = []
- for embeddings, grid in zip(vision_embedding, best_grids):
- embeddings = embeddings.cpu() # n x d x h
- outputs.append(
- dict(embeddings=embeddings,
- grid=grid,
- use_image_id=use_image_id))
+ B, L, _ = pixel_values.shape
+ pixel_values = pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+ tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+ max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+ patch_attn_mask = torch.zeros((B, 1, max_patches),
+ dtype=torch.bool,
+ device=self.model.device)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ if self.version == '2.5':
+ for j in range(B):
+ patch_attn_mask[j, :tgt_sizes[j][0] *
+ tgt_sizes[j][1]] = True
+ embeddings = self.model.vpm(
+ pixel_values.type(torch.half),
+ patch_attention_mask=patch_attn_mask).last_hidden_state
+ else:
+ for j in range(B):
+ patch_attn_mask[j, 0, :tgt_sizes[j][0] *
+ tgt_sizes[j][1]] = True
+ embeddings = self.model.vpm(
+ pixel_values.type(torch.half),
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes).last_hidden_state
- return outputs
+ embeddings = self.model.resampler(embeddings, tgt_sizes)
+ embeddings = torch.split(embeddings, num_patches, 0)
+ for embedding in embeddings:
+ embedding = embedding.split(1, dim=0)
+ outputs.extend([x.squeeze() for x in embedding])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
- @torch.no_grad()
- def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- images = [x.convert('RGB') for x in images]
- return self._forward_func(images, params)
+ def proc_messages(self, messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ idx = 0
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ if 'preprocess' not in message.keys():
+ continue
+ prompts = []
+ for x in message['preprocess']:
+                prompt = f'<image>{IMAGE_TOKEN}</image>'
+                if x.get('use_image_id', False):
+                    prompt = f'<image_id>{idx}</image_id>' + prompt
+ idx += 1
+ grid = x['best_grid']
+ if grid is not None:
+ if self.version == '2.5':
+                        slice = '\n'.join(
+                            [f'<image>{IMAGE_TOKEN}</image>' * grid[0]] *
+                            grid[1])
+                        prompt = f'{prompt}<slice>{slice}</slice>\n'
+                    elif self.version == '2.6':
+                        slice = '\n'.join(
+                            [f'<slice>{IMAGE_TOKEN}</slice>' * grid[0]] *
+                            grid[1])
+                        prompt = prompt + slice
+ prompt += '\n'
+ else:
+ prompt = (prompt +
+ '\n' if self.version == '2.6' else prompt)
+ prompts.append(prompt)
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = ''.join(prompts) + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
diff --git a/lmdeploy/vl/model/mllama.py b/lmdeploy/vl/model/mllama.py
index db0a0e9cbf..0cae71cd6c 100644
--- a/lmdeploy/vl/model/mllama.py
+++ b/lmdeploy/vl/model/mllama.py
@@ -2,192 +2,10 @@
from typing import Dict, List
-import torch
-import torch.nn.functional as F
-from PIL.Image import Image
-from transformers.modeling_outputs import BaseModelOutput
-from transformers.models.mllama.modeling_mllama import MllamaPreTrainedModel
-
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
-
-
-class MllamaVisionModelPatch(MllamaPreTrainedModel):
-
- def apply_class_embedding(self,
- hidden_state: torch.Tensor) -> torch.Tensor:
- batch_size, _, hidden_size = hidden_state.shape
- class_embedding = self.class_embedding.expand(batch_size, 1,
- hidden_size)
- class_embedding = class_embedding.to(hidden_state.device)
- hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
- return hidden_state
-
- def forward(
- self,
- pixel_values: torch.Tensor,
- aspect_ratio_ids: torch.Tensor,
- aspect_ratio_mask: torch.Tensor,
- output_attentions: bool = None,
- output_hidden_states: bool = None,
- return_dict: bool = None,
- ):
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else
- self.config.output_hidden_states)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa
-
- batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape # noqa
-
- pixel_values = pixel_values.reshape(
- batch_size * num_concurrent_media * num_tiles, num_channels,
- height, width)
- aspect_ratio_ids = aspect_ratio_ids.reshape(
- batch_size * num_concurrent_media, -1)
-
- # Patch embedding
- patch_embeds = self.patch_embedding(
- pixel_values.to(self.dtype).to(self.device))
- hidden_state = patch_embeds.flatten(2).transpose(1, 2)
-
- # Tile embeddings
- _, num_patches, dim = hidden_state.shape
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles, -1, dim)
- hidden_state = self.pre_tile_positional_embedding(
- hidden_state, aspect_ratio_ids)
-
- # Add cls token
- hidden_state = hidden_state.reshape(
- batch_size * num_concurrent_media * num_tiles, num_patches, dim)
- hidden_state = self.apply_class_embedding(hidden_state)
- num_patches += 1
-
- # Position embeddings
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles, num_patches, dim)
- hidden_state = self.gated_positional_embedding(hidden_state,
- aspect_ratio_ids)
-
- hidden_state = self.layernorm_pre(hidden_state)
-
- # Compute the number of tokens to pad
- num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
- # Compute padding tuple for pad function
- padding = (
- 0, 0, 0, num_padding_patches
- ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
- # Pad the tensor
- hidden_state = F.pad(hidden_state, padding, mode='constant', value=0)
- slice_index = -num_padding_patches if num_padding_patches > 0 else None
-
- # Prepare attention mask
- attention_mask = aspect_ratio_mask.reshape(
- batch_size * num_concurrent_media, -1)
- from transformers.models.mllama.modeling_mllama import \
- _prepare_aspect_ratio_attention_mask
- attention_mask = _prepare_aspect_ratio_attention_mask(
- aspect_ratio_mask=attention_mask,
- num_patches=self.num_patches,
- target_length=hidden_state.shape[2],
- dtype=self.dtype,
- )
-
- # Apply encoder
- hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
- dim)
- output = self.transformer(
- hidden_state,
- attention_mask=attention_mask,
- output_hidden_states=True,
- output_attentions=output_attentions,
- )
- hidden_state = output[0]
-
- hidden_state = self.layernorm_post(hidden_state)
-
- # Apply global encoder
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles,
- num_patches + num_padding_patches,
- dim)
- hidden_state = self.post_tile_positional_embedding(
- hidden_state, aspect_ratio_ids)
- hidden_state = hidden_state.reshape(
- batch_size * num_concurrent_media,
- num_tiles * (num_patches + num_padding_patches), dim)
- global_output = self.global_transformer(
- hidden_state,
- attention_mask=attention_mask,
- output_hidden_states=output_hidden_states,
- output_attentions=output_attentions,
- )
- hidden_state = global_output[0]
-
- # Remove padding form hidden state
- hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
- num_tiles,
- num_patches + num_padding_patches,
- dim)
- hidden_state = hidden_state[:, :, :slice_index]
- hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
- num_tiles, num_patches, dim)
-
- # Collect intermediate layer outputs from encoder output
- all_intermediate_hidden_states = output[1]
- # rewrite to sync device during accelerate pipeline parallel
- device = hidden_state.device
- all_intermediate_hidden_states = [
- s.to(device) for s in all_intermediate_hidden_states
- ]
- intermediate_hidden_states = torch.stack(
- all_intermediate_hidden_states, dim=-1)
- intermediate_hidden_states = intermediate_hidden_states[
- ..., self.intermediate_layers_indices]
-
- # Remove padding from intermediate hidden states
- intermediate_hidden_states = intermediate_hidden_states.reshape(
- batch_size * num_concurrent_media, num_tiles,
- num_patches + num_padding_patches, -1)
- intermediate_hidden_states = intermediate_hidden_states[:, :, :
- slice_index]
- intermediate_hidden_states = intermediate_hidden_states.reshape(
- batch_size, num_concurrent_media, num_tiles, num_patches, -1)
-
- # Concatenate final hidden state and intermediate hidden states
- hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
- dim=-1)
-
- if output_hidden_states:
- hidden_states = tuple(all_intermediate_hidden_states) + tuple(
- global_output[1])
- else:
- hidden_states = None
-
- if output_attentions:
- # global transformer in contrast to `self.transformer` doesn't
- # always return hidden states so we might go index out-of-range
- global_attn = tuple(
- global_output[2]) if output_hidden_states else tuple(
- global_output[1])
- attentions = tuple(output[2]) + global_attn
- else:
- attentions = None
-
- if not return_dict:
- return tuple(v for v in [hidden_state, hidden_states, attentions]
- if v is not None)
-
- return BaseModelOutput(
- last_hidden_state=hidden_state,
- hidden_states=hidden_states,
- attentions=attentions,
- )
def check_transformers():
- """check qwen_vl_utils."""
try:
from transformers import MllamaForConditionalGeneration # noqa: F401
except ImportError:
@@ -202,85 +20,60 @@ class MllamaVLModel(VisonModel):
_arch = 'MllamaForConditionalGeneration'
- def build_model(self):
- check_transformers()
-
- from transformers.models.mllama.modeling_mllama import \
- MllamaVisionModel
- MllamaVisionModel.forward = MllamaVisionModelPatch.forward
- MllamaVisionModel.apply_class_embedding = MllamaVisionModelPatch.apply_class_embedding # noqa
- from accelerate import init_empty_weights
- with init_empty_weights():
- config = self.hf_config
- config.quantization_config = {} # disable vision part quantization
- # disable accelerate check_tied_parameters_in_config
- config.tie_word_embeddings = False
- from transformers import MllamaForConditionalGeneration
- model = MllamaForConditionalGeneration._from_config(config)
- if not self.with_llm:
- del model.language_model
- else:
- self.vl_model = model
-
- from accelerate import load_checkpoint_and_dispatch
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map='auto' if not self.with_llm else {'': 'cpu'},
- max_memory=self.max_memory,
- no_split_module_classes=[
- 'MllamaPrecomputedPositionEmbedding',
- 'MllamaPrecomputedAspectRatioEmbedding',
- 'MllamaVisionEncoderLayer'
- ],
- dtype=config.torch_dtype)
-
- self.model = model.eval()
-
- # processor
+ def build_preprocessor(self):
from transformers import AutoProcessor
self.processor = AutoProcessor.from_pretrained(self.model_path)
self.image_token_id = 128256
- @torch.no_grad()
- def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- # only support image input
- if params is not None:
- assert len(images) == len(
- params), 'different length of images and params'
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to the spec of `super().preprocess`"""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ results = self.processor.image_processor(images=image,
+ return_tensors='pt')
+ results.update(image_size=image.size,
+ image_tokens=1,
+ image_token_id=self.image_token_id)
+ outputs.append(results)
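+        # Note: each entry in `outputs` is the image_processor's output
+        # (pixel_values, aspect_ratio_ids, aspect_ratio_mask, ...) plus the
+        # bookkeeping fields added above; they are stored under a synthetic
+        # 'preprocess' role so that `to_pytorch` can look them up later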
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ def build_model(self):
+ check_transformers()
+ if self.with_llm:
+ from transformers import MllamaForConditionalGeneration
+ model = MllamaForConditionalGeneration.from_pretrained(
+ self.model_path, device_map='cpu')
+ self.vl_model = model
else:
- params = [{}] * len(images)
- # resize images with abnormal shape
- # TODO try catch image feature extraction in pipeline and
- # throw error back to users
- for i, image in enumerate(images):
- size = image.size
- if any([s < 3 for s in size]):
- images[i] = image.resize([s * 3 for s in size])
- image_inputs = self.processor.image_processor(images=images,
- return_tensors='pt')
- pixel_values = image_inputs['pixel_values'].to(
- self.model.vision_model.device)
- pixel_values = pixel_values.type(self.model.vision_model.dtype)
- aspect_ratio_ids = image_inputs['aspect_ratio_ids'].to(
- self.model.vision_model.device)
- aspect_ratio_mask = image_inputs['aspect_ratio_mask'].to(
- self.model.vision_model.device)
- vision_outputs = self.model.vision_model(
- pixel_values=pixel_values,
- aspect_ratio_ids=aspect_ratio_ids,
- aspect_ratio_mask=aspect_ratio_mask,
- output_hidden_states=False,
- output_attentions=False,
- return_dict=True)
- cross_attention_states = vision_outputs[0]
- cross_attention_states = self.model.multi_modal_projector(
- cross_attention_states)
- _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape)
- cross_attention_states = cross_attention_states.view(
- bsz, -1, image_token_dim).split([1] * len(images))
- return cross_attention_states
+            raise NotImplementedError('turbomind does not support mllama yet')
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+ IMAGE_TOKEN = '<|image|>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = (IMAGE_TOKEN) * n_images + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py
index 9abae7a309..eccf62ebb6 100644
--- a/lmdeploy/vl/model/molmo.py
+++ b/lmdeploy/vl/model/molmo.py
@@ -3,11 +3,9 @@
from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from lmdeploy.utils import get_logger
-from lmdeploy.vl.constants import IMAGE_TOKEN
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
@@ -20,20 +18,26 @@ class MolmoVisionModel(VisonModel):
_arch = 'MolmoForCausalLM'
+ def build_preprocessor(self):
+ self.processor = AutoProcessor.from_pretrained(self.model_path,
+ trust_remote_code=True,
+ torch_dtype=torch.half,
+ device_map='auto')
+
def build_model(self):
- """Load model."""
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
with init_empty_weights():
- config = self.hf_config
- model = AutoModelForCausalLM.from_config(config,
+ model = AutoModelForCausalLM.from_config(self.hf_config,
trust_remote_code=True)
+
+ self.vl_model = model
if not self.with_llm:
# Remove nn modules other than embedding from the LLM model
for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']:
del model.model.transformer[key]
- self.token_embedding = model.model.transformer.wte
- else:
- self.vl_model = model
+ self.token_embedding = model.model.transformer.wte
with disable_logging():
load_checkpoint_and_dispatch(
@@ -43,118 +47,161 @@ def build_model(self):
max_memory=self.max_memory,
no_split_module_classes=[
'ResidualAttentionBlock', 'Embedding'
- ])
+ ],
+ dtype=torch.half)
# We need eval mode to freeze the weights in model, thus,
# avoid randomness in inference.
self.model = model.eval()
- self.config = config
- self.processor = AutoProcessor.from_pretrained(self.model_path,
- trust_remote_code=True,
- torch_dtype='auto',
- device_map='auto')
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to the `super.preprocess() for spec."""
+ for i, message in enumerate(messages):
+ if not isinstance(message['content'], List):
+ continue
+ images = [
+ x['image'] for x in message['content'] if x['type'] == 'image'
+ ]
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = f' User: {content[0]}'
+ tokens = self.processor.tokenizer.encode(prompt,
+ add_special_tokens=False)
+ # preprocess images. The output is a dict, which is
+ # {
+ # 'input_ids': torch.Tensor,
+            #     'images': torch.Tensor,  # (n_image, n_patch, d_model)
+            #     'image_input_idx': torch.Tensor,  # (n_image, n_patch)
+            #     'image_masks': torch.Tensor,  # (n_image, n_patch)
+ # }
+ result = self.processor.process(images=images, tokens=tokens)
+ # remove the bos from input_ids which is prepended by molmo's
+ # processor
+ input_ids = result['input_ids'][1:]
+ result.update(input_ids=input_ids)
+ messages[i].update(preprocess=result)
+ return messages
@torch.no_grad()
def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[Dict]:
- """forward the model with given input.
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
Args:
- images (List): [None] it is not used
- params (List): the inputs after precessing GPT4V messages in
- `MolmoChatTemplateWrapper`. Its format is like the following:
- [[
- {'role': 'user', 'content': 'user prompt'},
- {'role': 'asssistant', 'content': 'assistant prompt'},
- {'role': 'user', 'content': 'user prompt', 'images': [PIL image list]},
- ...
- ]]
- """ # noqa
-
- messages = params[0]
- assert isinstance(messages, List)
- # append an assistant message to `messages`
- messages.append(dict(role='assistant', content=''))
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ for i, message in enumerate(messages):
+ if 'preprocess' not in message.keys():
+ continue
+ inputs = message['preprocess']
+ # get input_ids of embedding
+ inputs = {
+ k: v.to(self.model.device).unsqueeze(0)
+ for k, v in inputs.items()
+ }
+ input_ids = inputs['input_ids']
+ # (batch_size, num_image, num_patch, d_model)
+ images = inputs['images']
+ # (batch_size, num_image, num_patch)
+ image_input_idx = inputs['image_input_idx']
+ image_masks = inputs['image_masks']
+ batch_size, seq_len = input_ids.size()
+ assert batch_size == 1
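+            # input_ids may contain -1 entries; zero them out so that the
+            # embedding lookup below does not index out of range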
+ input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
+ embeddings = self.model.model.transformer.wte(input_ids)
+ images = images.to(self.model.dtype)
+ image_masks = image_masks.to(self.model.dtype)
+ logger.info(f'vision forward shape: {images.shape}')
+ image_features, _ = self.model.model.vision_backbone(
+ images, image_masks)
+ num_image, num_patch = image_features.shape[1:3]
+ assert image_input_idx.shape == (batch_size, num_image, num_patch)
+
+ # insert the image feature into the embedding.
+ image_features = image_features.view(batch_size,
+ num_image * num_patch, -1)
+ image_input_idx = image_input_idx.view(batch_size,
+ num_image * num_patch)
+ valid = image_input_idx >= 0
+ batch_idx = torch.arange(batch_size, device=embeddings.device)
+ batch_idx = torch.tile(batch_idx[:, None],
+ [1, image_features.shape[1]])
+ image_features = image_features.to(embeddings.device)
+ # Since we remove bos_id from input_ids during `preprocess`,
+            # the index `image_input_idx[valid]` should be shifted to the left
+            # by subtracting 1
+ index = image_input_idx[valid] - 1
+ embeddings[batch_idx[valid], index] += image_features[valid]
+ assert embeddings.shape[:2] == (batch_size, seq_len)
+ messages[i].update(
+ dict(forward=dict(input_ids=input_ids.flatten(),
+ embeddings=embeddings)))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages):
+ prompt = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ role, content = message['role'], message['content']
+ if isinstance(content, List):
+ n_images = len([1 for x in content if x['type'] == 'image'])
+ content = [x['text'] for x in content if x['type'] == 'text']
+ prompt.append(' User: ' + (IMAGE_TOKEN + '\n') * n_images +
+ content[0])
+ else:
+ if role == 'user':
+ prompt.append(f' User: {content}')
+ elif role == 'assistant':
+ prompt.append(f' Assistant:{content}')
+ else:
+ assert 0, f'molmo does not support role {role}, message is {message}' # noqa
+ prompt.append(' Assistant:')
+ return ''.join(prompt)
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ assert 0, 'molmo is not supported by pytorch engine'
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
# results is a list of tuple(input_ids, embeddings)
results = []
- # the concat prompt. It is not used during inference but to adhere the
- # interface definition of `_get_prompt_input` in `class VLAsyncEngine`
- prompts = ''
# Prepend BOS
# qwen2 and olmo do not have a BOS, and instead use EOS as a generic
# separator token.
bos = (self.processor.tokenizer.bos_token_id
or self.processor.tokenizer.eos_token_id)
results.append(([bos], None))
+
for i, message in enumerate(messages):
- if 'images' in message.keys():
- prompts += ' User: ' + (IMAGE_TOKEN + '\n') * len(
- message['images']) + message['content']
- prompt = f' User: {message["content"]}'
- tokens = self.processor.tokenizer.encode(
- prompt, add_special_tokens=False)
- # preprocess images. The output is a dict
- inputs = self.processor.process(images=message['images'],
- tokens=tokens)
- inputs = {
- k: v.to(self.model.device).unsqueeze(0)
- for k, v in inputs.items()
- }
- input_ids = inputs['input_ids']
- # remove the bos from input_ids which is prepended by molmo's
- # processor
- input_ids = input_ids[:, 1:]
- images = inputs[
- 'images'] # (batch_size, num_image, num_patch, d_model)
- image_input_idx = inputs[
- 'image_input_idx'] # (batch_size, num_image, num_patch)
- image_masks = inputs['image_masks']
- batch_size, seq_len = input_ids.size()
- assert batch_size == 1
-
- # Get embeddings of input.
- if input_ids is not None:
- input_ids = input_ids * (input_ids != -1).to(
- input_ids.dtype)
- embeddings = self.model.model.transformer.wte(input_ids)
- image_features, _ = self.model.model.vision_backbone(
- images, image_masks)
- num_image, num_patch = image_features.shape[1:3]
- assert image_input_idx.shape == (batch_size, num_image,
- num_patch)
-
- # insert the image feature into the embedding.
- image_features = image_features.view(batch_size,
- num_image * num_patch, -1)
- image_input_idx = image_input_idx.view(batch_size,
- num_image * num_patch)
-
- valid = image_input_idx >= 0
- batch_idx = torch.arange(batch_size, device=embeddings.device)
- batch_idx = torch.tile(batch_idx[:, None],
- [1, image_features.shape[1]])
- image_features = image_features.to(embeddings.device)
- embeddings[batch_idx[valid],
- image_input_idx[valid]] += image_features[valid]
- assert embeddings.shape[:2] == (batch_size, seq_len)
- results.append((input_ids.flatten().tolist(), embeddings))
+ prompt = ''
+ role, content = message['role'], message['content']
+ if isinstance(content, List):
+ forward_result = message.pop('forward')
+ input_ids = forward_result['input_ids']
+ embeddings = forward_result['embeddings']
+ results.append((input_ids.tolist(), embeddings))
else:
- role = message['role']
- content = message['content']
- assert isinstance(content, str)
- prompt = ''
if role == 'user':
prompt = f' User: {content}'
elif role == 'assistant':
prompt = f' Assistant:{content}'
else:
assert 0, f'molmo does not support role {role}, message is {message}' # noqa
+ if i == len(messages) - 1:
+ # the last message
+ assert role == 'user', f'the role of last message is expected to be user, but got {role}' # noqa
+ prompt += ' Assistant:'
+ if prompt:
input_ids = self.processor.tokenizer.encode(
prompt, add_special_tokens=False)
results.append((input_ids, None))
- prompts += prompt
# concat input_ids from results, calculate the range in the input_ids
# where embeddings will be copied to
@@ -169,9 +216,9 @@ def forward(self,
input_embedding_ranges.append((start, end))
input_ids += _input_ids
start += len(_input_ids)
- return [
- dict(prompt=prompts,
- input_ids=input_ids,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges)
- ]
+
+ prompt = self.proc_messages(messages)
+ return dict(prompt=prompt,
+ input_ids=input_ids,
+ input_embeddings=input_embeddings,
+ input_embedding_ranges=input_embedding_ranges)
diff --git a/lmdeploy/vl/model/phi3_vision.py b/lmdeploy/vl/model/phi3_vision.py
index 032b8404da..ff00b5d1d9 100644
--- a/lmdeploy/vl/model/phi3_vision.py
+++ b/lmdeploy/vl/model/phi3_vision.py
@@ -1,198 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import List
+from typing import Dict, List
-import torch
-from PIL.Image import Image
from transformers import AutoProcessor
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
-
-
-# from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py # noqa E501
-def _process_image_embedding(self, pixel_values: torch.Tensor,
- image_sizes: torch.Tensor):
- """process image embedding."""
- img_embeds = pixel_values
- img_sizes = image_sizes
- target_device = pixel_values.device
- target_dtype = pixel_values.dtype
- if self.use_hd_transform and img_sizes is not None and len(img_sizes):
- assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501
- # img_embeds: (num_images, max_num_crops, 3, H, W)
- # img_sizes: (num_images, 2).view(1, -1)
-
- bs = img_embeds.shape[0]
- # Nx(HW)xC
- img_features = self.get_img_features(img_embeds.flatten(0, 1))
- base_feat_height = base_feat_width = int(img_features.shape[1]**0.5)
-
- assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501
-
- # bs x max_num_crops x (24x24) x C
- img_features = img_features.view(bs, -1,
- base_feat_height * base_feat_width,
- self.image_dim_out)
- C = self.image_dim_out
- H = base_feat_height
-
- output_imgs = []
- output_len = []
- # training is tensor, inference is list
- if isinstance(img_sizes, torch.Tensor):
- img_sizes = img_sizes.view(-1, 2)
- for _bs in range(bs):
- h, w = img_sizes[_bs]
- h = h // 336
- w = w // 336
- B_ = h * w
-
- # 1 x (24x24) x 1024
- global_img_feature = img_features[_bs, :1]
-
- # 1 x 12 x 12 x 4096
- glb_img = global_img_feature.reshape(1, H, H, C).reshape(
- 1, H // 2, 2, H // 2, 2,
- C).contiguous().permute(0, 1, 3, 2, 4,
- 5).reshape(1, H // 2, H // 2,
- 4 * C).contiguous()
- temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
-
- # 1 x 156 x 4096
- glb_img = torch.cat([glb_img, temp_glb_GN],
- dim=2).reshape(1, -1, 4 * C)
-
- # (max_num_crops-1) x (12x12) x C
- sub_img = img_features[_bs, 1:]
- # 16x574x1024
- # get rid of padding sub_img
- sub_img = sub_img[:B_]
-
- # (num_crops, 12, 2, 12, 2, 1024)->(num_crops, 12, 12, 2, 2, 1024)
- # -> (num_crops, 12*12, 4*1024)
- sub_img = sub_img.reshape(B_, H, H, C).reshape(
- B_, H // 2, 2, H // 2, 2,
- C).contiguous().permute(0, 1, 3, 2, 4,
- 5).reshape(B_, -1, 4 * C).contiguous()
- sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(
- 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C)
- temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
- sub_img = torch.cat([sub_img, temp_sub_GN],
- dim=2).reshape(1, -1, 4 * C)
- # (1, num_img_tokens, 1024*4)
-
- # glb + sub
- if self.hd_transform_order == 'glb_sub':
- output_imgs.append(
- torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
- elif self.hd_transform_order == 'sub_glb':
- output_imgs.append(
- torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
- else:
- raise NotImplementedError(
- f'hd_transform_order = {self.hd_transform_order}'
- ) # noqa E501
-
- temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
- assert temp_len == output_imgs[-1].shape[
- 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501
- output_len.append(temp_len)
-
- img_set_tensor = []
- for _output_img in output_imgs:
- img_feature_proj = self.img_projection(
- _output_img.to(target_device).to(target_dtype))
- img_set_tensor.append(img_feature_proj)
- elif img_embeds.ndim == 4:
- tt = (self.get_img_features(img_embeds).to(target_device).to(
- target_dtype).reshape(-1, self.image_dim_out))
- img_set_tensor = self.img_projection(tt) # adapted visual features.
- elif img_embeds.ndim == 3:
- tt = (img_embeds.to(target_device).to(target_dtype).view(
- -1, self.image_dim_out))
- img_set_tensor = self.img_projection(tt) # adapted visual features.
- else:
- raise NotImplementedError
- return img_set_tensor
+from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
@VISION_MODELS.register_module()
-class Phi3VisionModel(VisonModel):
- """Llava hf vision model."""
+class Phi3VisionModel(LlavaHfVisionModel):
+ """Phi3-vision model."""
_arch = 'Phi3VForCausalLM'
- def build_model(self):
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
- from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
- with init_empty_weights(), warnings.catch_warnings():
- warnings.simplefilter('ignore')
- from transformers import AutoModelForCausalLM
- model = AutoModelForCausalLM.from_config(self.hf_config,
- trust_remote_code=True)
- if not self.with_llm:
- del model.lm_head
- del model.model.layers
- del model.model.norm
- del model.model.embed_tokens
- del model.model.vision_embed_tokens.wte
- else:
- self.vl_model = model
-
- no_split_module_classes = ['CLIPEncoderLayer']
- max_memory = get_balanced_memory(
- model,
- max_memory=self.max_memory,
- dtype=torch.half,
- no_split_module_classes=no_split_module_classes)
- device_map = infer_auto_device_map(
- model,
- no_split_module_classes=no_split_module_classes,
- max_memory=max_memory,
- dtype=torch.half)
- same_device_keys = [('model.vision_embed_tokens.img_projection',
- 'model.vision_embed_tokens.sub_GN',
- 'model.vision_embed_tokens.glb_GN')]
- for keys in same_device_keys:
- keys = [k for k in keys if k in device_map]
- if len(keys) <= 1:
- continue
- for k in keys[1:]:
- device_map[k] = device_map[keys[0]]
-
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map=device_map if not self.with_llm else {'': 'cpu'},
- no_split_module_classes=no_split_module_classes,
- dtype=torch.half)
-
- model.eval()
- self.model = model
- # processor
+ def build_preprocessor(self):
processor = AutoProcessor.from_pretrained(self.model_path,
trust_remote_code=True)
if hasattr(processor, 'tokenizer'):
del processor.tokenizer
- processor.prtokenizer = None
- self.processor = processor.image_processor
+ processor.tokenizer = None
self.processor = processor
- @torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- process_outputs = self.processor.image_processor(
- images, return_tensors='pt').to(device=self.model.device,
- dtype=self.model.dtype)
- pixel_values = process_outputs['pixel_values']
- image_sizes = process_outputs['image_sizes']
- image_features = _process_image_embedding(
- self.model.model.vision_embed_tokens,
- pixel_values=pixel_values,
- image_sizes=image_sizes)
- outputs = [x.squeeze() for x in image_features]
- return outputs
+ def build_model(self):
+ if self.with_llm:
+ from transformers import AutoModelForCausalLM
+ self.vl_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path, device_map='cpu', trust_remote_code=True)
+ else:
+            raise NotImplementedError('turbomind does not support phi3v yet')
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refers to `super.preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ result = self.processor.image_processor(image, return_tensors='pt')
+ h = result['image_sizes'][0][0].item() // 336
+ w = result['image_sizes'][0][1].item() // 336
+ image_tokens = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
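+            # e.g. image_sizes of (336, 336) gives h = w = 1 and thus
+            # (1 * 1 + 1) * 144 + 1 + (1 + 1) * 12 = 313 image tokens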
+ result.update(
+ dict(image_size=image.size,
+ image_tokens=image_tokens,
+ image_token_id=0))
+ outputs.append(result)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py
index 3968f27d97..49631ccf35 100644
--- a/lmdeploy/vl/model/qwen.py
+++ b/lmdeploy/vl/model/qwen.py
@@ -1,14 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
+from typing import Dict, List
import torch
-from PIL.Image import Image
from transformers import AutoModelForCausalLM
+from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging
+logger = get_logger('lmdeploy')
+
@VISION_MODELS.register_module()
class QwenVisionModel(VisonModel):
@@ -16,19 +18,33 @@ class QwenVisionModel(VisonModel):
_arch = 'QWenLMHeadModel'
+ def build_preprocessor(self):
+ from torchvision import transforms
+ from torchvision.transforms import InterpolationMode
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ std = (0.26862954, 0.26130258, 0.27577711)
+ image_size = self.hf_config.visual['image_size']
+ self.image_transform = transforms.Compose([
+ transforms.Resize((image_size, image_size),
+ interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=mean, std=std),
+ ])
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights():
config = self.hf_config
config.quantization_config = {} # disable vision part quantization
model = AutoModelForCausalLM.from_config(config,
trust_remote_code=True)
+ self.vl_model = model
if not self.with_llm:
del model.lm_head
for key in ['wte', 'h', 'ln_f']:
setattr(model.transformer, key, None)
- else:
- self.vl_model = model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
@@ -60,13 +76,86 @@ def build_model(self):
self.model = model.transformer.visual.eval()
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refers to `super.preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = self.image_transform(image)
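+            # Qwen-VL's visual encoder resamples every image to a fixed
+            # 256 tokens, hence `image_tokens=256` below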
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=256,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.model.image_transform(x) for x in outputs]
- outputs = torch.stack(outputs, dim=0)
- outputs = self.model(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.stack(pixel_values, dim=0)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ feats = self.model(pixel_values)
+ feats = torch.split(feats, 1, dim=0)
+ outputs.extend([x.squeeze() for x in feats])
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ x['text'] for x in message['content'] if x['type'] == 'text'
+ ]
+ prompt = content[0]
+            if IMAGE_TOKEN not in prompt:
+                prompt = ''.join([
+                    f'Picture {str(i)}:{IMAGE_TOKEN}\n'
+                    for i in range(n_images)
+                ]) + prompt
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py
index 3eb3c1541c..ed9da332e0 100644
--- a/lmdeploy/vl/model/qwen2.py
+++ b/lmdeploy/vl/model/qwen2.py
@@ -1,12 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-
from typing import Dict, List
import torch
-from PIL.Image import Image
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
-from lmdeploy.vl.model.utils import disable_logging
def check_qwen_vl_deps_install():
@@ -15,7 +12,7 @@ def check_qwen_vl_deps_install():
import qwen_vl_utils # noqa: F401
except ImportError:
raise ImportError(
- 'please install qwen_vl_utils by pip install qwen_vl_utils' # noqa: E501
+ 'please install qwen_vl_utils by `pip install qwen_vl_utils`' # noqa: E501
)
try:
from transformers import Qwen2VLForConditionalGeneration # noqa: F401
@@ -31,85 +28,105 @@ class Qwen2VLModel(VisonModel):
_arch = 'Qwen2VLForConditionalGeneration'
+ def build_preprocessor(self):
+ check_qwen_vl_deps_install()
+ from transformers import AutoProcessor
+ self.processor = AutoProcessor.from_pretrained(self.model_path)
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to `super().preprocess()` for spec."""
+ from qwen_vl_utils import process_vision_info
+
+ images = self.collect_images(messages)
+ optional_keys = {
+ 'resized_height', 'resized_width', 'min_pixels', 'max_pixels'
+ }
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+
+ item = dict(type='image', image=image)
+ item.update({
+ key: params[key]
+ for key in params.keys() if key in optional_keys
+ })
+ image_inputs, _ = process_vision_info([dict(content=[item])])
+ result = self.processor.image_processor(images=image_inputs,
+ videos=None,
+ return_tensors='pt')
+ merge_length = self.processor.image_processor.merge_size**2
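+            # number of LLM-side image tokens, e.g. an image_grid_thw of
+            # (1, 80, 80) with merge_size 2 gives 6400 // 4 = 1600 tokens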
+ image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
+ result.update(
+ dict(image_size=image.size,
+ image_tokens=image_tokens,
+ image_token_id=0))
+ outputs.append(result)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
def build_model(self):
check_qwen_vl_deps_install()
from transformers import Qwen2VLForConditionalGeneration
if self.with_llm:
- model = Qwen2VLForConditionalGeneration.from_pretrained(
- self.hf_config._name_or_path, trust_remote_code=True)
- model.half()
- self.vl_model = model
+ self.vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
+ self.model_path, device_map='cpu')
else:
- from accelerate import init_empty_weights
- with init_empty_weights():
- config = self.hf_config
- config.quantization_config = {
- } # disable vision part quantization
- # disable accelerate check_tied_parameters_in_config
- # for Qwen2-VL-2B-Instruct
- config.tie_word_embeddings = False
-
- model = Qwen2VLForConditionalGeneration._from_config(config)
- del model.model
- del model.lm_head
- model.half()
- from accelerate import load_checkpoint_and_dispatch
- with disable_logging():
- load_checkpoint_and_dispatch(
- model=model,
- checkpoint=self.model_path,
- device_map='auto' if not self.with_llm else {'': 'cpu'},
- max_memory=self.max_memory,
- no_split_module_classes=['Qwen2VLVisionBlock'],
- dtype=torch.half)
-
- self.model = model.eval()
-
- # processor
- from transformers import AutoProcessor
- self.processor = AutoProcessor.from_pretrained(self.model_path)
+ raise NotImplementedError(
+                'turbomind does not support qwen2-vl yet')
@torch.no_grad()
def forward(self,
- images: List[Image],
- params: List[Dict] = None) -> List[torch.Tensor]:
- """forward."""
- # only support image input
- if params is not None:
- assert len(images) == len(
- params), 'different length of images and params'
- else:
- params = [{}] * len(images)
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
- from qwen_vl_utils import process_vision_info
- images = [x.convert('RGB') for x in images]
- content = []
- optional_keys = [
- 'resized_height', 'resized_width', 'min_pixels', 'max_pixels'
- ]
- for image, param in zip(images, params):
- item = dict(type='image', image=image)
- item.update({k: param[k] for k in optional_keys if k in param})
- content.append(item)
- messages = [dict(content=content)]
- image_inputs, _ = process_vision_info(messages)
- image_inputs = self.processor.image_processor(images=image_inputs,
- videos=None,
- return_tensors='pt')
- pixel_values = image_inputs['pixel_values'].to(
- self.model.visual.get_device())
- image_grid_thw = image_inputs['image_grid_thw'].to(
- self.model.visual.get_device())
- pixel_values = pixel_values.type(self.model.visual.get_dtype())
- image_embeds = self.model.visual(pixel_values,
- grid_thw=image_grid_thw).cpu()
- merge_length = self.processor.image_processor.merge_size**2
- split_size = image_inputs['image_grid_thw'].prod(dim=1) // merge_length
- image_embeds = image_embeds.split(split_size.tolist())
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ assert 0, 'TODO: support turbomind engine'
- outputs = []
- for i, embeddings in enumerate(image_embeds):
- outputs.append(
- dict(embeddings=embeddings,
- grid_thw=image_inputs['image_grid_thw'][i].tolist()))
- return outputs
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = content[0]
+ if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt:
+ prompt = prompt.replace(
+ IMAGE_TOKEN,
+ f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>')
+ else:
+ # Qwen2-VL-2B-Instruct will concat image and user prompt
+ # according to their order in the content list
+ # we insert image token before user prompt by default. The
+ # user can use custom image token position if they want the
+ # same decorated prompt as Qwen2-VL
+ prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \
+ n_images + prompt
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ """return to the information needed by pytorch engine."""
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py
index 96bc900c02..3c72d0c29f 100644
--- a/lmdeploy/vl/model/xcomposer2.py
+++ b/lmdeploy/vl/model/xcomposer2.py
@@ -5,7 +5,7 @@
import sys
import warnings
from contextlib import contextmanager
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
import torch
from PIL.Image import Image
@@ -19,6 +19,17 @@
logger = get_logger('lmdeploy')
+def check_xcomposer_install():
+ try:
+ # WARNING! we have to do this otherwise the model_type is wrong for
+ # xcomposer2d5
+ import decord # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "No module named 'decord'. Please install decord by `pip install decord`" # noqa
+ )
+
+
class ModelType(enum.Enum):
"""Request type."""
XCOMPOSER2 = enum.auto()
@@ -83,6 +94,17 @@ def init_empty_vit(model_path):
class Xcomposer2VisionModel(VisonModel):
"""InternLM-Xcomposer2 vision model."""
+ def __init__(self,
+ model_path: str,
+ with_llm: bool = False,
+ max_memory: Dict[int, int] = None,
+ hf_config: AutoConfig = None,
+ backend: str = ''):
+ super().__init__(model_path, with_llm, max_memory, hf_config, backend)
+ check_xcomposer_install()
+ self.model_type, self.module = get_xcomposer_type(self.model_path)
+ logger.info(f'matching type of {self.model_type}')
+
@classmethod
def match(cls, config: AutoConfig):
"""check whether the config match the model."""
@@ -94,7 +116,37 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+
+ import torchvision.transforms as transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+ if self.model_type in [
+ ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD
+ ]:
+ self.HD_transform = self.module
+ self.vis_processor = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+ self.preprocess_func = (self._preprocess_2d5 if self.model_type
+ == ModelType.XCOMPOSER2D5 else
+ self._preprocess_4khd_7b)
+ else:
+ self.vis_processor = transforms.Compose([
+ transforms.Resize(
+ (self.hf_config.img_size, self.hf_config.img_size),
+ interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+ self.preprocess_func = self._preprocess_7b
+
def build_model(self):
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights(), warnings.catch_warnings(), \
init_empty_vit(self.model_path):
@@ -106,23 +158,10 @@ def build_model(self):
model.vit.resize_pos()
model.vit.vision_tower.vision_model.post_layernorm.to_empty(
device='cpu').half()
+ self.vl_model = model
if not self.with_llm:
del model.model
del model.output
- else:
- self.vl_model = model
-
- # additional components.
- model_type, module = get_xcomposer_type(self.model_path)
- logger.info(f'matching type of {model_type}')
- if model_type == ModelType.XCOMPOSER2D5:
- self.HD_transform = module
- self._forward_func = self._forward_2d5
- elif model_type == ModelType.XCOMPOSER2_4KHD:
- self.HD_transform = module
- self._forward_func = self._forward_4khd_7b
- else:
- self._forward_func = self._forward_7b
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
@@ -156,51 +195,117 @@ def build_model(self):
self.model = model.eval()
- def _forward_2d5(self, images: List[Image]) -> List[torch.Tensor]:
- """internlm-xcomposer2d5-7b vit forward."""
- outputs = [x.convert('RGB') for x in images]
- hd_num = 6 if len(images) > 1 else 24
- outputs = [self.HD_transform(x, hd_num=hd_num) for x in outputs]
- outputs = [
- self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half)
- for x in outputs
- ]
- embeds, split = self.model.vit(outputs, self.model.plora_glb_GN,
- self.model.plora_sub_GN)
- embeds = self.model.vision_proj(embeds)
- embeds = torch.split(embeds, split, dim=1)
- embeds = [x.squeeze() for x in embeds]
- return embeds
-
- def _forward_7b(self, images: List[Image]) -> List[torch.Tensor]:
- """internlm-xcomposer2-7b vit forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [
- self.model.vis_processor(x).unsqueeze(0).half() for x in outputs
- ]
- outputs = torch.cat(outputs, dim=0)
- outputs = self.model.vit(outputs)
- outputs = self.model.vision_proj(outputs)
- outputs = torch.split(outputs, 1, dim=0)
- outputs = [x.squeeze() for x in outputs]
- return outputs
-
- def _forward_4khd_7b(self, images: List[Image]) -> List[torch.Tensor]:
- """internlm-xcomposer2-4khd-7b vit forward."""
- outputs = [x.convert('RGB') for x in images]
- outputs = [self.HD_transform(x, hd_num=25) for x in outputs]
- outputs = [
- self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half)
- for x in outputs
- ]
- embeds, split = self.model.vit(outputs, self.model.plora_glb_GN,
- self.model.plora_sub_GN)
- embeds = self.model.vision_proj(embeds)
- embeds = torch.split(embeds, split, dim=1)
- embeds = [x.squeeze() for x in embeds]
- return embeds
+ def _preprocess_2d5(self, image: Image, params: Dict) -> Dict:
+ """image preprocessing for internlm-xcomposer2d5-7b."""
+ hd_num = params.get('hd_num', 24)
+ image = self.HD_transform(image, hd_num=hd_num)
+ pixel_values = self.vis_processor(image).unsqueeze(0).half()
+ w, h = image.size
+ n_token_per_image = int((h * w + 1) * 400 + 1 + (h + 1) * 20)
+ return pixel_values, n_token_per_image
+
+ def _preprocess_7b(self, image: Image, params: Dict) -> Dict:
+ """image preprocessing for internlm-xcomposer2-7b."""
+ pixel_values = self.vis_processor(image).unsqueeze(0).half()
+ return pixel_values, 256
+
+ def _preprocess_4khd_7b(self, image: Image, params: Dict) -> Dict:
+ """image preprocessing for internlm-xcomposer2-4khd-7b."""
+ image = self.HD_transform(image, hd_num=25)
+ pixel_values = self.vis_processor(image).unsqueeze(0).half()
+ w, h = image.size
+ n_token_per_image = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
+ return pixel_values, n_token_per_image
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to `super().preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values, n_token = self.preprocess_func(image, params)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=n_token,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
@torch.no_grad()
- def forward(self, images: List[Image]) -> List[torch.Tensor]:
- """forward."""
- return self._forward_func(images)
+ def forward(self,
+ messages: List[Dict],
+ max_batch_size: int = 1) -> List[Dict]:
+ """extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
+ inputs = inputs[0]
+ outputs = []
+ for idx in range(0, len(inputs), max_batch_size):
+ if self.model_type in [
+ ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD
+ ]:
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ embeds, split = self.model.vit(pixel_values,
+ self.model.plora_glb_GN,
+ self.model.plora_sub_GN)
+ embeds = self.model.vision_proj(embeds)
+ embeds = torch.split(embeds, split, dim=1)
+ embeds = [x.squeeze() for x in embeds]
+ else:
+ pixel_values = [
+ x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
+ ]
+ pixel_values = torch.cat(pixel_values, dim=0)
+ logger.info(f'vision forward shape: {pixel_values.shape}')
+ embeds = self.model.vit(pixel_values)
+ embeds = self.model.vision_proj(embeds)
+ embeds = torch.split(embeds, 1, dim=0)
+ embeds = [x.squeeze() for x in embeds]
+ outputs.extend(embeds)
+ messages.append(dict(role='forward', content=outputs))
+ return messages
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len(
+ [1 for x in message['content'] if x['type'] == 'image'])
+ content = [
+ item['text'] for item in message['content']
+ if item['type'] == 'text'
+ ]
+ prompt = ' '.join([IMAGE_TOKEN] * n_images) + content[0]
+ prompt_messages.append(dict(role='user', content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template,
+ sequence_start)
+ return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
+ sequence_start)
diff --git a/lmdeploy/vl/model/yi.py b/lmdeploy/vl/model/yi.py
index 34b993322e..f8d3a907ff 100644
--- a/lmdeploy/vl/model/yi.py
+++ b/lmdeploy/vl/model/yi.py
@@ -1,12 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from contextlib import contextmanager
+from os import path as osp
+from typing import Dict, List
import torch.nn as nn
from transformers import AutoConfig
from lmdeploy.vl.model.base import VISION_MODELS
-from lmdeploy.vl.model.llava import LlavaVisionModel, check_llava_install
+from lmdeploy.vl.model.llava import (LlavaVisionModel, check_llava_install,
+ process_images)
from .utils import disable_transformers_logging, rewrite_ctx
@@ -96,8 +99,22 @@ def match(cls, config: AutoConfig):
return True
return False
+ def build_preprocessor(self):
+ from transformers import CLIPImageProcessor
+ vision_tower_name = osp.join(self.model_path,
+ self.hf_config.mm_vision_tower)
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ vision_tower_name)
+ config = AutoConfig.from_pretrained(vision_tower_name)
+ image_size = config.image_size
+ patch_size = config.patch_size
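+        # e.g. a CLIP ViT with image_size 336 and patch_size 14 yields
+        # (336 // 14) ** 2 = 576 tokens per image (577 with the cls patch)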
+ self.n_token_per_image = (image_size // patch_size)**2
+ if self.hf_config.mm_vision_select_feature == 'cls_patch':
+ self.n_token_per_image += 1
+
def build_model(self):
- """build model & load weights."""
+ """build the vision part of a VLM model when backend is turbomind, or
+ load the whole VLM model when `self.with_llm==True`"""
check_llava_install()
global _model_path
@@ -105,3 +122,19 @@ def build_model(self):
with init_yi_model(), disable_transformers_logging():
super().build_model()
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """refer to `super().preprocess() for spec."""
+ images = self.collect_images(messages)
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+ pixel_values = process_images([image], self.image_processor,
+ self.config)
+ outputs.append(
+ dict(pixel_values=pixel_values,
+ image_size=image.size,
+ image_tokens=self.n_token_per_image,
+ image_token_id=0))
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py
deleted file mode 100644
index cdf398868a..0000000000
--- a/lmdeploy/vl/templates.py
+++ /dev/null
@@ -1,550 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import asyncio
-from typing import Dict, List, Tuple, Union
-
-import PIL
-import PIL.Image
-
-from lmdeploy.archs import get_model_arch
-from lmdeploy.model import BaseModel
-from lmdeploy.utils import get_logger
-from lmdeploy.vl.constants import IMAGE_TOKEN
-from lmdeploy.vl.utils import load_image
-
-logger = get_logger('lmdeploy')
-
-VLPromptType = Union[str, Tuple[str, PIL.Image.Image],
- Tuple[str, List[PIL.Image.Image]]]
-
-
-class VLChatTemplateWrapper:
- """vl chat template wrapper."""
-
- def __init__(self, chat_template: BaseModel):
- self.chat_template = chat_template
-
- def prompt_to_messages(self, prompt: VLPromptType):
- """convert prompt to GTP4V format."""
- messages = {
- 'role': 'user',
- 'content': [{
- 'type': 'text',
- 'text': '',
- }]
- }
- if isinstance(prompt, str):
- messages['content'][0]['text'] = prompt
- else:
- prompt, images = prompt
- if not isinstance(images, list):
- images = [images]
- messages['content'][0]['text'] = prompt
- for image in images:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if isinstance(image, str):
- image = load_image(image)
- item = {
- 'type': 'image_data',
- 'image_data': {
- 'data': image
- }
- }
- elif isinstance(image, PIL.Image.Image):
- item = {
- 'type': 'image_data',
- 'image_data': {
- 'data': image
- }
- }
- else:
- raise ValueError(
- 'image should be a str(url/path) or PIL.Image.Image')
-
- messages['content'].append(item)
-
- return [messages]
-
- async def async_collect_pil_images(
- self, messages: Dict) -> List[Tuple[PIL.Image.Image, Dict]]:
- """collect image from messages."""
- images_with_kwargs = []
- for message in messages:
- role = message['role']
- content = message['content']
- if role != 'user' or isinstance(content, str):
- continue
- for item in content:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if item['type'] == 'image_url':
- item_copy = item['image_url'].copy()
- try:
- url = item_copy.pop('url')
- images_with_kwargs.append([url, item_copy])
- except KeyError:
- logger.error(f'invalid format {message}')
- elif item['type'] == 'image_data':
- item_copy = item['image_data'].copy()
- try:
- data = item_copy.pop('data')
- images_with_kwargs.append([data, item_copy])
- except KeyError:
- logger.error(f'invalid format {message}')
-
- def _inner_call(i, images):
- url_or_data = images[i][0]
- images[i][0] = load_image(url_or_data)
-
- await asyncio.gather(*[
- asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
- images_with_kwargs)
- for i in range(len(images_with_kwargs))
- ])
-
- return images_with_kwargs
-
- def append_image_token(self, prompt, num_images: int):
- """append image token to user prompt."""
- if IMAGE_TOKEN in prompt:
- return prompt
- return (IMAGE_TOKEN + '\n') * num_images + prompt
-
- def convert_messages(self, messages, sequence_start=True):
- """convert GPT4V message format to GPT4 text format."""
- new_messages = []
- for message in messages:
- role = message['role']
- content = message['content']
- if role != 'user' or isinstance(content, str):
- if isinstance(content, list):
- text = content[0]['text']
- message = {'role': role, 'content': text}
- new_messages.append(message)
- continue
- num_images = 0
- for item in content:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if item['type'] == 'image_url':
- num_images += 1
- elif item['type'] == 'image_data':
- num_images += 1
- elif item['type'] == 'text':
- prompt = item['text']
- if num_images > 0:
- # add IMAGE_TOKEN to user prompt
- prompt = self.append_image_token(prompt, num_images)
- new_item = {'role': 'user', 'content': prompt}
- new_messages.append(new_item)
- return new_messages
-
- def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
- """convert messages to decorated prompt."""
- if isinstance(messages, str):
- return self.chat_template.messages2prompt(messages, sequence_start)
- new_messages = self.convert_messages(messages, sequence_start)
- return self.chat_template.messages2prompt(new_messages, sequence_start)
-
-
-class LlavaVLChatTemplateWrapper(VLChatTemplateWrapper):
- """Llava vl chat template."""
- pass
-
-
-class YiVLChatTemplateWrapper(VLChatTemplateWrapper):
- """Yi vl chat template."""
- pass
-
-
-class InternVLChatTemplateWrapper(VLChatTemplateWrapper):
- """InternVL chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
-        # lmdeploy uses <IMAGE_TOKEN> as image token
- # internvl uses special tags
-        if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}</img>' not in prompt:
-            prompt = prompt.replace(f'{IMAGE_TOKEN}',
-                                    f'<img>{IMAGE_TOKEN}</img>')
-            prompt = prompt.replace('</img><img>', '')
-            prompt = prompt.replace('<img><img>', '<img>')
-            prompt = prompt.replace('</img></img>', '</img>')
- elif IMAGE_TOKEN not in prompt:
-            prompt = f'<img>{IMAGE_TOKEN * num_images}</img>\n' + prompt
- return prompt
-
-
-class DeepSeekVLChatTemplateWrapper(VLChatTemplateWrapper):
- """DeepSeek vl chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if IMAGE_TOKEN in prompt:
- return prompt
- logger.error(
- f'for deepseek-vl model, the user should insert the {IMAGE_TOKEN} '
- 'to user prompt manually, please read https://lmdeploy.readthedocs'
- '.io/en/latest/inference/vl_pipeline.html for more details.')
- if num_images == 1:
- return f'{IMAGE_TOKEN}{prompt}'
- res = ''
- for i in range(num_images):
- res += f'{IMAGE_TOKEN} is Figure {str(i)}.\n'
- res = res + prompt
- return res
-
-
-class QwenVLChatTemplateWrapper(VLChatTemplateWrapper):
- """Qwen vl chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if IMAGE_TOKEN in prompt:
- return prompt
- res = ''
- for i in range(num_images):
- res += f'Picture {str(i)}:{IMAGE_TOKEN}\n'
- res = res + prompt
- return res
-
-
-class Qwen2VLChatTemplateWrapper(VLChatTemplateWrapper):
- """qwen2 vl."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt:
- prompt = prompt.replace(
- IMAGE_TOKEN, f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>')
- else:
- # Qwen2-VL-2B-Instruct will concat image and user prompt according
- # to their order in the content list
- # we insert image token before user prompt by default. The user can
- # use custom image token position if they want the same decorated
- # prompt as Qwen2-VL
- prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \
- num_images + prompt
- return prompt
-
- def get_mrope_info(self,
- seq_len: int,
- grid_thws: List[Tuple[int, int, int]] = None,
- embedding_ranges: List[Tuple[int, int]] = None):
- import torch
- if grid_thws is None:
- mrope_position_ids = torch.arange(seq_len).expand(3, -1)
- mrope_position_delta = torch.tensor([0], dtype=torch.long)
- else:
- mrope_position_ids = [
- torch.arange(embedding_ranges[0][0]).expand(3, -1)
- ]
- st_idx = embedding_ranges[0][0]
- for i, (grid_thw, embedding_range) in enumerate(
- zip(grid_thws, embedding_ranges)):
- llm_grid_t, llm_grid_h, llm_grid_w = grid_thw
- llm_grid_h //= 2
- llm_grid_w //= 2
- t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
- -1, llm_grid_h * llm_grid_w).flatten()
- h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
- llm_grid_t, -1, llm_grid_w).flatten()
- w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
- llm_grid_t, llm_grid_h, -1).flatten()
- mrope_position_ids.append(
- torch.stack([t_index, h_index, w_index]) + st_idx)
- st_idx += max(llm_grid_h, llm_grid_w)
- if i < len(embedding_ranges) - 1:
- text_len = embedding_ranges[i +
- 1][0] - embedding_ranges[i][1]
- else:
- text_len = seq_len - embedding_range[1]
- mrope_position_ids.append(
- torch.arange(text_len).expand(3, -1) + st_idx)
- st_idx += text_len
- mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
- mrope_position_delta = torch.tensor([st_idx - seq_len],
- dtype=torch.long)
-
- return mrope_position_ids, mrope_position_delta
-
-
-class CogVLMChatTemplateWrapper(VLChatTemplateWrapper):
- """cogvlm chat template wrapper."""
-
- def __init__(self, chat_template: BaseModel):
- from lmdeploy.model import Vicuna
- self.chat_template = chat_template
- self.llm_chat_template = Vicuna(eoa=chat_template.eoa,
- stop_words=chat_template.stop_words)
-
- def convert_messages(self, messages, sequence_start=True):
- """convert GPT4V message format to GPT4 text format."""
- new_messages = []
- for message in messages:
- role = message['role']
- content = message['content']
- if role != 'user' or isinstance(content, str):
- new_messages.append(message)
- continue
- num_images = 0
- for item in content:
- if item['type'] == 'image_url':
- num_images += 1
- elif item['type'] == 'image_data':
- num_images += 1
- elif item['type'] == 'text':
- prompt = item['text']
-
- new_item = {
- 'role': 'user',
- 'content': prompt,
- 'num_images': num_images
- }
- new_messages.append(new_item)
- return new_messages
-
- def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
- """convert messages to decorated prompt."""
- if isinstance(messages, str):
- return self.chat_template.messages2prompt(messages, sequence_start)
- new_messages = self.convert_messages(messages, sequence_start)
- prompt = ''
- for i, msg in enumerate(new_messages):
- num_images = msg.pop('num_images', 0)
- if num_images == 0:
- role = msg['role']
- msg = self.llm_chat_template.messages2prompt([msg],
- sequence_start
- and i == 0)
- msg = dict(role=role, content=msg)
- prompt_i = self.chat_template.messages2prompt([msg], sequence_start
- and i == 0)
- if num_images > 0:
- prompt_i = (IMAGE_TOKEN * num_images) + prompt_i
- prompt += prompt_i
- return prompt
-
-
-class InternLMXComposer2TemplateWrapper(VLChatTemplateWrapper):
- """InternLM-XComposer2 chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- if IMAGE_TOKEN in prompt:
- return prompt
- logger.warning(f'auto append {IMAGE_TOKEN} at the beginning, '
- 'the user can manually insert the token to prompt')
- return ' '.join([IMAGE_TOKEN] * num_images) + prompt
-
-
-class MiniGeminiLlamaTempateWrapper(VLChatTemplateWrapper):
- """Qwen vl chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- if num_images == 0:
- return prompt
- if IMAGE_TOKEN in prompt:
- return prompt
- res = f'{IMAGE_TOKEN}\n'
- assert num_images <= 1, 'MiniGeminiLlama accepts 1 input image'
- res = res + prompt
- return res
-
-
-class MllamaTempateWrapper(VLChatTemplateWrapper):
- """Mllama chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- """append image tokens to user prompt."""
- return f'{IMAGE_TOKEN * num_images}{prompt}'
-
-
-class MiniCPMVTempateWrapper(VLChatTemplateWrapper):
- """MiniCPM-Llama3-V-2_5 chat template."""
-
- def append_image_token(self, prompt, num_images: int):
- if IMAGE_TOKEN in prompt:
- return prompt
- prompt = f'{IMAGE_TOKEN}\n' * num_images + prompt
- return prompt
-
- def update_image_token(self, prompt, features):
- _features = []
- _prompt = []
- segs = prompt.split(f'{IMAGE_TOKEN}\n')
- for i, seg in enumerate(segs):
- if i > 0 and i <= len(features):
- _feat = features[i - 1]['embeddings'].split(1)
- _feat = [x.squeeze() for x in _feat]
- _features.extend(_feat)
- _seg = f'{IMAGE_TOKEN}'
- if len(_feat) > 1:
- grid = features[i - 1]['grid']
- if grid is not None:
- _slice = '\n'.join(
- [f'{IMAGE_TOKEN}' * grid[0]] *
- grid[1])
- _seg = f'{_seg}{_slice}\n'
- _prompt.append(_seg)
- _prompt.append(seg)
- _prompt = ''.join(_prompt)
- return _prompt, _features
-
-
-class MiniCPMV26TempateWrapper(MiniCPMVTempateWrapper):
- """MiniCPM-V-2_6 chat template."""
-
- def update_image_token(self, prompt, features):
- _features = []
- _prompt = []
- segs = prompt.split(f'{IMAGE_TOKEN}\n')
- idx = 0
- for i, seg in enumerate(segs):
- if i > 0 and i <= len(features):
- _feat = features[i - 1]['embeddings'].split(1)
- _feat = [x.squeeze() for x in _feat]
- _features.extend(_feat)
- _seg = f'{IMAGE_TOKEN}'
- if features[i - 1].get('use_image_id', False):
- _seg = f'{idx}' + _seg
- idx += 1
- if len(_feat) > 1:
- grid = features[i - 1]['grid']
- if grid is not None:
- _slice = '\n'.join(
- [f'{IMAGE_TOKEN}' * grid[0]] *
- grid[1])
- _seg = _seg + _slice
- _seg += '\n'
- _prompt.append(_seg)
- _prompt.append(seg)
- _prompt = ''.join(_prompt)
- return _prompt, _features
-
-
-class GLM4VChatTemplateWrapper(VLChatTemplateWrapper):
- """glm-4v chat template."""
- pass
-
-
-class MolmoChatTemplateWrapper(VLChatTemplateWrapper):
-
- async def async_collect_pil_images(
- self, messages: List[Dict]) -> List[Tuple[PIL.Image.Image, Dict]]:
- """collect images from messages.
-
- Args:
- messages (List[Dict]): a user request of GPT4V message format
- """
- if isinstance(messages, Dict):
- messages = [messages]
- assert isinstance(messages, List)
-
- out_messages = [None] * len(messages)
-
- def _inner_call(i, in_messages, out_messages):
- role = in_messages[i]['role']
- content = in_messages[i]['content']
- if role != 'user' or isinstance(content, str):
- # means message is user's prompt input or assistant's prompt,
- # returning it directory
- out_messages[i] = in_messages[i]
- return
- # the role is a user and the content is a list
- assert isinstance(content, List)
- message = dict(role=role, content='', images=[])
- for item in content:
- # 'image_url': means url or local path to image.
- # 'image_data': means PIL.Image.Image object.
- if item['type'] == 'image_url':
- try:
- image = load_image(item['image_url']['url'])
- message['images'].append(image)
- except KeyError:
- logger.error(f'invalid format {message}')
- elif item['type'] == 'image_data':
- try:
- image = load_image(item['image_data']['data'])
- message['images'].append(image)
- except KeyError:
- logger.error(f'invalid format {message}')
- elif item['type'] == 'text':
- message['content'] = item['text']
- else:
- logger.error(f'unexpected content type {message}')
- out_messages[i] = message
-
- await asyncio.gather(*[
- asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
- messages, out_messages)
- for i in range(len(messages))
- ])
- return [(None, out_messages)]
-
- def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
- """Return a placeholder "IMAGE_TOKEN" so that
- `vl_asyn_engine._get_prompt_input` can know that it."""
- if isinstance(messages, str):
- return self.chat_template.messages2prompt(messages, sequence_start)
- else:
- _messages = []
- for message in messages:
- role, content = message['role'], message['content']
- if role != 'user' or isinstance(content, str):
- _messages.append(message)
- continue
- for item in content:
- item_type = item['type']
- if item_type in ['image_url', 'image_data']:
- # Return the image placeholder so that
- # `vl_asyn_engine._get_prompt_input` can know that the
- # request contains images
- return IMAGE_TOKEN
- _messages.append(dict(role=role, content=item[item_type]))
- return self.chat_template.messages2prompt(_messages,
- sequence_start)
-
-
-def get_vl_prompt_template(model_path: str, chat_template: BaseModel,
- model_name: str) -> VLChatTemplateWrapper:
- """get vision language prompt template."""
- assert type(chat_template) != type(BaseModel()), 'failed to match ' \
- 'chat template, please explicit set chat_template_config' # noqa E721
- if model_name == 'yi-vl':
- return YiVLChatTemplateWrapper(chat_template)
- arch, cfg = get_model_arch(model_path)
- if arch == 'QWenLMHeadModel':
- return QwenVLChatTemplateWrapper(chat_template)
- elif arch in [
- 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM',
- 'LlavaForConditionalGeneration',
- 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM'
- ]:
- return LlavaVLChatTemplateWrapper(chat_template)
- elif arch == 'MultiModalityCausalLM': # deepseek-vl
- return DeepSeekVLChatTemplateWrapper(chat_template)
- elif arch == 'MllamaForConditionalGeneration': # llama 3.2
- return MllamaTempateWrapper(chat_template)
- elif arch == 'CogVLMForCausalLM':
- return CogVLMChatTemplateWrapper(chat_template)
- elif arch in ['InternLMXComposer2ForCausalLM', 'InternLM2ForCausalLM']:
- return InternLMXComposer2TemplateWrapper(chat_template)
- elif arch == 'InternVLChatModel':
- return InternVLChatTemplateWrapper(chat_template)
- elif arch in ['MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM']:
- return MiniGeminiLlamaTempateWrapper(chat_template)
- elif arch == 'MiniCPMV':
- version_map = {
- '2.5': MiniCPMVTempateWrapper,
- '2.6': MiniCPMV26TempateWrapper
- }
- version = str(getattr(cfg, 'version', '2.5'))
- return version_map[version](chat_template)
- elif arch == 'ChatGLMModel':
- return GLM4VChatTemplateWrapper(chat_template)
- elif arch == 'Qwen2VLForConditionalGeneration':
- return Qwen2VLChatTemplateWrapper(chat_template)
- elif arch == 'MolmoForCausalLM':
- return MolmoChatTemplateWrapper(chat_template)
- raise ValueError(f'unsupported vl_prompt_template with arch {arch}')
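
The removed chat-template code above builds M-RoPE position ids by expanding (t, h, w) indices over each image grid (halved by the 2x2 patch merge), advancing `st_idx` by `max(llm_grid_h, llm_grid_w)` after each image block, and appending a plain `arange(text_len)` span for the text that follows. A minimal standalone sketch of the per-image part, useful for sanity-checking that logic (the function name is hypothetical, not an LMDeploy API):

```python
import torch


def mrope_vision_position_ids(grid_t: int, grid_h: int, grid_w: int,
                              st_idx: int = 0) -> torch.Tensor:
    """Sketch of the deleted M-RoPE helper: (t, h, w) position ids for one
    image whose spatial grid is halved by the 2x2 patch merge."""
    h, w = grid_h // 2, grid_w // 2
    t_index = torch.arange(grid_t).view(-1, 1).expand(-1, h * w).flatten()
    h_index = torch.arange(h).view(1, -1, 1).expand(grid_t, -1, w).flatten()
    w_index = torch.arange(w).view(1, 1, -1).expand(grid_t, h, -1).flatten()
    return torch.stack([t_index, h_index, w_index]) + st_idx


# A 1x4x4 grid merges down to 1x2x2, i.e. 4 vision tokens:
# tensor([[0, 0, 0, 0],    # t
#         [0, 0, 1, 1],    # h
#         [0, 1, 0, 1]])   # w
print(mrope_vision_position_ids(1, 4, 4))
```
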
diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt
index 05d74bbe72..81f538275c 100644
--- a/requirements/runtime_ascend.txt
+++ b/requirements/runtime_ascend.txt
@@ -1,5 +1,5 @@
accelerate>=0.29.3
-dlinfer-ascend>=0.1.2
+dlinfer-ascend>=0.1.3
einops
fastapi
fire
@@ -16,7 +16,8 @@ safetensors
sentencepiece
shortuuid
tiktoken
-torch<=2.4.0,>=2.0.0
-torchvision<=0.19.0,>=0.15.0
+torch<=2.4.0,>=2.3.1
+torch-npu==2.3.1
+torchvision<=0.19.0,>=0.18.1
transformers
uvicorn
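
The Ascend runtime now pins `torch-npu==2.3.1` and raises the matching `torch`/`torchvision` floors. A hedged post-install sanity check (a sketch, not part of the patch; it reuses the `torch_npu.npu.is_available()` probe the old setup.py relied on):

```python
# Run inside the environment that installed requirements/runtime_ascend.txt.
import torch
import torch_npu  # provided by the torch-npu==2.3.1 pin

print('torch        :', torch.__version__)        # expected >=2.3.1, <=2.4.0
print('torch_npu    :', torch_npu.__version__)     # expected 2.3.1.*
print('NPU available:', torch_npu.npu.is_available())
```
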
diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt
new file mode 100644
index 0000000000..41af6039bd
--- /dev/null
+++ b/requirements/runtime_cuda.txt
@@ -0,0 +1,22 @@
+accelerate>=0.29.3
+einops
+fastapi
+fire
+mmengine-lite
+numpy<2.0.0
+openai
+outlines<0.1.0
+peft<=0.11.1
+pillow
+protobuf
+pydantic>2.0.0
+pynvml
+safetensors
+sentencepiece
+shortuuid
+tiktoken
+torch<=2.5.1,>=2.0.0
+torchvision<=0.20.1,>=0.15.0
+transformers
+triton==3.0.0; sys_platform == "linux"
+uvicorn
diff --git a/requirements/runtime.txt b/requirements/runtime_maca.txt
similarity index 78%
rename from requirements/runtime.txt
rename to requirements/runtime_maca.txt
index 400c492b09..f65b3827cd 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime_maca.txt
@@ -1,4 +1,4 @@
-accelerate>=0.29.3
+accelerate==0.32.1
einops
fastapi
fire
@@ -18,5 +18,5 @@ tiktoken
torch<=2.4.0,>=2.0.0
torchvision<=0.19.0,>=0.15.0
transformers
-triton>=2.2.0,<=3.0.0; sys_platform == "linux"
+triton>=2.1.0; sys_platform == "linux"
uvicorn
diff --git a/requirements.txt b/requirements_cuda.txt
similarity index 70%
rename from requirements.txt
rename to requirements_cuda.txt
index 91d38808f1..7c1d387dfb 100644
--- a/requirements.txt
+++ b/requirements_cuda.txt
@@ -1,4 +1,4 @@
-r requirements/build.txt
--r requirements/runtime.txt
+-r requirements/runtime_cuda.txt
-r requirements/lite.txt
-r requirements/serve.txt
diff --git a/requirements_maca.txt b/requirements_maca.txt
new file mode 100644
index 0000000000..075b132c8c
--- /dev/null
+++ b/requirements_maca.txt
@@ -0,0 +1,4 @@
+-r requirements/build.txt
+-r requirements/runtime_maca.txt
+-r requirements/lite.txt
+-r requirements/serve.txt
diff --git a/setup.py b/setup.py
index 7a08ac7919..52e180d8a2 100644
--- a/setup.py
+++ b/setup.py
@@ -4,18 +4,14 @@
from setuptools import find_packages, setup
-npu_available = False
-try:
- import torch_npu
-
- npu_available = torch_npu.npu.is_available()
-except ImportError:
- pass
-
pwd = os.path.dirname(__file__)
version_file = 'lmdeploy/version.py'
+def get_target_device():
+ return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda')
+
+
def readme():
with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f:
content = f.read()
@@ -154,16 +150,12 @@ def gen_packages_items():
setup_requires=parse_requirements('requirements/build.txt'),
tests_require=parse_requirements('requirements/test.txt'),
install_requires=parse_requirements(
- 'requirements/runtime_ascend.txt'
- if npu_available else 'requirements/runtime.txt'),
+ f'requirements/runtime_{get_target_device()}.txt'),
extras_require={
'all':
- parse_requirements('requirements_ascend.txt'
- if npu_available else 'requirements.txt'),
- 'lite':
- parse_requirements('requirements/lite.txt'),
- 'serve':
- parse_requirements('requirements/serve.txt')
+ parse_requirements(f'requirements_{get_target_device()}.txt'),
+ 'lite': parse_requirements('requirements/lite.txt'),
+ 'serve': parse_requirements('requirements/serve.txt')
},
has_ext_modules=check_ext_modules,
classifiers=[
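
With the NPU auto-detection dropped, setup.py now resolves every requirement file from the `LMDEPLOY_TARGET_DEVICE` environment variable (default `cuda`). A hedged sketch of the resulting lookup, with file names taken from this patch; the helper below is illustrative rather than the actual setup.py code:

```python
import os


def get_target_device() -> str:
    # Same default as the helper added to setup.py: CUDA unless overridden.
    return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda')


def requirement_files() -> dict:
    """Files setup.py would read for the current target device; with this
    patch the known targets are 'cuda', 'ascend' and 'maca'."""
    device = get_target_device()
    return {
        'install_requires': f'requirements/runtime_{device}.txt',
        'all': f'requirements_{device}.txt',
        'lite': 'requirements/lite.txt',
        'serve': 'requirements/serve.txt',
    }


# e.g. `LMDEPLOY_TARGET_DEVICE=ascend pip install -e .` would resolve to
# requirements/runtime_ascend.txt and requirements_ascend.txt.
print(requirement_files())
```
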
diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py
index 7d4b7a7f3a..e56de44b37 100644
--- a/tests/pytorch/kernel/test_flash_attention.py
+++ b/tests/pytorch/kernel/test_flash_attention.py
@@ -10,20 +10,26 @@ def _conti_input(data, q_seqlens):
return data
-def _make_bias(q_seqlens, history_lens, neg_val):
- full_seq_lens = q_seqlens + history_lens
+def _make_bias(q_seqlens, history_lens, neg_val, causal):
+ kv_seqlens = q_seqlens + history_lens
max_seq_len = q_seqlens.max().item()
- max_full_len = full_seq_lens.max().item()
- seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens]
- for r, l in zip(seq_ranges, q_seqlens):
- r[l:] = -max_full_len
- seq_ranges = torch.stack(seq_ranges, dim=0).cuda()
- kv_ranges = [torch.arange(max_full_len) for _ in full_seq_lens]
- kv_ranges = torch.stack(kv_ranges, 0).cuda()
- mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:,
- None,
- None]
- return mask.float() * neg_val
+ max_kv_len = kv_seqlens.max().item()
+ if causal:
+ seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens]
+ for r, l in zip(seq_ranges, q_seqlens):
+ r[l:] = -max_kv_len
+ seq_ranges = torch.stack(seq_ranges, dim=0).cuda()
+ kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens]
+ kv_ranges = torch.stack(kv_ranges, 0).cuda()
+ mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] >
+ history_lens[:, None, None])
+ return mask.float() * neg_val
+ else:
+ q_mask = torch.arange(max_seq_len)[None].cuda() < q_seqlens[:, None]
+ k_mask = torch.arange(max_kv_len)[None].cuda() < kv_seqlens[:, None]
+ mask = q_mask[:, :, None] & k_mask[:, None, :]
+
+ return (~mask).float() * neg_val
def _naive_attention(batched_q, batched_kv, bias):
@@ -100,6 +106,10 @@ def num_heads_q(self, request):
def num_heads_k(self, request):
yield request.param
+ @pytest.fixture
+ def causal(self, request):
+ yield request.param
+
@pytest.fixture
def q_seqlens(self, request):
yield torch.tensor(request.param, device='cuda')
@@ -138,8 +148,8 @@ def batched_kv(self, q_seqlens, history_lens, num_heads_k, head_dim_k,
head_dim_v, dtype):
torch.manual_seed(123)
batch_size = len(q_seqlens)
- full_seq_lens = q_seqlens + history_lens
- max_seq_len = full_seq_lens.max().item()
+ kv_seqlens = q_seqlens + history_lens
+ max_seq_len = kv_seqlens.max().item()
k = torch.rand(batch_size,
max_seq_len,
num_heads_k,
@@ -167,9 +177,9 @@ def conti_kv(self, kv_seqlens, batched_kv):
yield (conti_k, conti_v)
@pytest.fixture
- def mask(self, q_seqlens, history_lens):
+ def mask(self, q_seqlens, history_lens, causal):
neg_val = -1e30
- yield _make_bias(q_seqlens, history_lens, neg_val)
+ yield _make_bias(q_seqlens, history_lens, neg_val, causal)
@pytest.fixture
def gt(self, batched_q, batched_kv, mask):
@@ -183,11 +193,13 @@ def conti_gt(self, gt, q_seqlens):
@pytest.mark.parametrize('head_dim_v', [32], indirect=True)
@pytest.mark.parametrize('num_heads_q', [8, 2], indirect=True)
@pytest.mark.parametrize('num_heads_k', [2], indirect=True)
+ @pytest.mark.parametrize('causal', [True, False], indirect=True)
@pytest.mark.parametrize(['q_seqlens', 'history_lens'],
[([30, 50, 70, 90], [50, 40, 30, 20])],
indirect=True)
def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens,
- kv_start_loc, kv_seqlens, head_dim_v, conti_gt):
+ kv_start_loc, kv_seqlens, head_dim_v, causal,
+ conti_gt):
from lmdeploy.pytorch.kernels.cuda.flashattention import \
flash_attention_fwd
max_seq_len = q_seqlens.max().item()
@@ -202,7 +214,8 @@ def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
- max_seqlen=max_seq_len)
+ max_seqlen=max_seq_len,
+ causal=causal)
torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)
@pytest.fixture
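
The reworked `_make_bias` now returns either a causal bias (query `i` may attend to keys `0..history_len + i`) or, with `causal=False`, a plain padding mask over the valid query/key lengths, and the new `causal` fixture drives `flash_attention_fwd` through both paths. A CPU-only adaptation of that helper for inspecting the two mask shapes outside the GPU test (the `.cuda()` calls are dropped; names are illustrative):

```python
import torch


def make_bias(q_seqlens, history_lens, neg_val, causal):
    """CPU adaptation of the _make_bias test helper above (sketch only)."""
    kv_seqlens = q_seqlens + history_lens
    max_q, max_kv = q_seqlens.max().item(), kv_seqlens.max().item()
    if causal:
        seq_ranges = [torch.arange(max_q) for _ in q_seqlens]
        for r, l in zip(seq_ranges, q_seqlens):
            r[l:] = -max_kv  # padded query rows end up fully masked
        seq_ranges = torch.stack(seq_ranges, dim=0)
        kv_ranges = torch.stack([torch.arange(max_kv) for _ in kv_seqlens], 0)
        mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] >
                history_lens[:, None, None])
        return mask.float() * neg_val
    # non-causal: only padding beyond each real query/key length is masked
    q_mask = torch.arange(max_q)[None] < q_seqlens[:, None]
    k_mask = torch.arange(max_kv)[None] < kv_seqlens[:, None]
    return (~(q_mask[:, :, None] & k_mask[:, None, :])).float() * neg_val


q_lens, hist = torch.tensor([2]), torch.tensor([1])
print(make_bias(q_lens, hist, -1e30, causal=True))   # key 2 hidden from query 0
print(make_bias(q_lens, hist, -1e30, causal=False))  # all zeros: no padding
```
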
diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py
index 3b78053a74..0e53283a87 100644
--- a/tests/test_lmdeploy/test_model.py
+++ b/tests/test_lmdeploy/test_model.py
@@ -220,7 +220,7 @@ def test_llama3_1():
},
}]
actual_prompt = model.messages2prompt(messages, tools=tools)
- expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa
+ expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa
assert actual_prompt == expected_prompt
diff --git a/tests/test_lmdeploy/test_vl_encode.py b/tests/test_lmdeploy/test_vl/test_vl_encode.py
similarity index 100%
rename from tests/test_lmdeploy/test_vl_encode.py
rename to tests/test_lmdeploy/test_vl/test_vl_encode.py
diff --git a/tests/test_lmdeploy/test_vl_template.py b/tests/test_lmdeploy/test_vl_template.py
deleted file mode 100644
index cf8abf9e44..0000000000
--- a/tests/test_lmdeploy/test_vl_template.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import PIL
-
-from lmdeploy.model import MODELS
-from lmdeploy.vl.constants import IMAGE_TOKEN
-from lmdeploy.vl.templates import VLChatTemplateWrapper
-
-
-def test_prompt_to_messages():
- model = MODELS.get('llava-v1')()
- templtae = VLChatTemplateWrapper(model)
- out = templtae.prompt_to_messages('hi')
- assert isinstance(out, list) and isinstance(out[0], dict)
- im = PIL.Image.new(mode='RGB', size=(200, 200))
- out = templtae.prompt_to_messages(('hi', [im]))
- assert isinstance(out, list) and isinstance(out[0], dict)
-
-
-def test_messages2prompt():
- model = MODELS.get('llava-v1')()
- templtae = VLChatTemplateWrapper(model)
- messages = [
- dict(role='user',
- content=[
- dict(type='text', text='q1'),
- dict(type='image_url', image_url=dict(url='xxx'))
- ])
- ]
- prompt = templtae.messages2prompt(messages)
- assert isinstance(prompt, str)
- assert prompt.count(IMAGE_TOKEN) == 1
- expected = (
- 'A chat between a curious human and an artificial intelligence '
- 'assistant. The assistant gives helpful, detailed, and polite '
- "answers to the human's questions. USER: "
- '\nq1 ASSISTANT:')
- assert prompt == expected
-
- messages.append({'role': 'assistant', 'content': 'a1'})
- messages.append({'role': 'user', 'content': 'q2'})
- prompt = templtae.messages2prompt(messages)
- expected = (
- 'A chat between a curious human and an artificial intelligence '
- 'assistant. The assistant gives helpful, detailed, and polite '
- "answers to the human's questions. USER: "
- '\nq1 ASSISTANT: a1USER: q2 ASSISTANT:')
- assert prompt == expected
-
-
-def test_internvl2_conv():
- # https://huggingface.co/OpenGVLab/InternVL2-8B/blob/3bfd3664dea4f3da628785f5125d30f889701253/conversation.py
- from transformers.dynamic_module_utils import get_class_from_dynamic_module
- get_conv_template = get_class_from_dynamic_module(
- 'conversation.get_conv_template', 'OpenGVLab/InternVL2-8B')
- template = get_conv_template('internlm2-chat')
- question1 = 'question1'
- template.append_message(template.roles[0], question1)
- template.append_message(template.roles[1], None)
- model = MODELS.get('internvl2-internlm2')()
- messages = [dict(role='user', content=question1)]
- assert template.get_prompt() == model.messages2prompt(messages)
-
- answer1 = 'answer1'
- template.messages[-1][1] = answer1
- question2 = 'question2'
- template.append_message(template.roles[0], question2)
- template.append_message(template.roles[1], None)
- messages.append(dict(role='assistant', content=answer1))
- messages.append(dict(role='user', content=question2))
- assert template.get_prompt() == model.messages2prompt(messages)
-
-
-def test_llava_conv_chatml_direct():
- model = MODELS.get('llava-chatml')()
- templtae = VLChatTemplateWrapper(model)
- messages = [
- dict(role='user',
- content=[
- dict(type='text', text='q1'),
- dict(type='image_url', image_url=dict(url='xxx'))
- ])
- ]
-
- prompt = templtae.messages2prompt(messages)
- expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>'
- '<|im_start|>user\n\nq1<|im_end|>'
- '<|im_start|>assistant\n')
- assert prompt == expected
-
- messages.append({'role': 'assistant', 'content': 'a1'})
- messages.append({'role': 'user', 'content': 'q2'})
- prompt = templtae.messages2prompt(messages)
- expected = ('<|im_start|>system\nAnswer the questions.<|im_end|>'
- '<|im_start|>user\n\nq1<|im_end|>'
- '<|im_start|>assistant\na1<|im_end|>'
- '<|im_start|>user\nq2<|im_end|>'
- '<|im_start|>assistant\n')
- assert prompt == expected
-
-
-def test_custom_image_token():
- from lmdeploy.vl.templates import DeepSeekVLChatTemplateWrapper
- model = MODELS.get('deepseek-vl')()
- template = DeepSeekVLChatTemplateWrapper(model)
-
- def create_user(query: str):
- item = dict(role='user', content=[dict(type='text', text=query)])
- num = query.count(IMAGE_TOKEN)
- for _ in range(num):
- item['content'].append(
- dict(type='image_url', image_url=dict(url='xxx')))
- return item
-
- def create_assistant(response: str):
- return dict(role='assistant', content=response)
-
- messages = [create_user(f'{IMAGE_TOKEN} q1')]
- prompt = template.messages2prompt(messages)
- expected = ('You are a helpful language and vision assistant. You are able'
- ' to understand the visual content that the user provides, and'
- ' assist the user with a variety of tasks using natural '
- 'language.\n\nUser: q1\n\nAssistant:')
- assert prompt == expected
-
- messages.append(create_assistant('a1'))
- messages.append(create_user(f'q2 {IMAGE_TOKEN}'))
- prompt = template.messages2prompt(messages)
- expected = ('You are a helpful language and vision assistant. You are able'
- ' to understand the visual content that the user provides, and'
- ' assist the user with a variety of tasks using natural '
- 'language.\n\nUser: q1\n\nAssistant: '
- 'a1<|end▁of▁sentence|>User: q2 \n\nAssistant:')
- assert prompt == expected