From ee5f34b1c2c71b2d56054a5ca23fe1c50c1458bb Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Mon, 23 Sep 2024 18:44:26 +0200 Subject: [PATCH 01/50] [CI/Build] use setuptools-scm to set __version__ (#4738) Co-authored-by: youkaichao --- .gitignore | 3 + Dockerfile | 5 +- Dockerfile.cpu | 4 +- Dockerfile.neuron | 23 +++---- Dockerfile.openvino | 5 +- Dockerfile.ppc64le | 12 +++- Dockerfile.rocm | 9 ++- Dockerfile.tpu | 17 ++++-- Dockerfile.xpu | 13 ++-- .../getting_started/cpu-installation.rst | 2 +- pyproject.toml | 10 +++- requirements-build.txt | 3 +- setup.py | 60 ++++--------------- tests/test_embedded_commit.py | 7 ++- vllm/__init__.py | 4 +- vllm/version.py | 12 ++-- 16 files changed, 94 insertions(+), 95 deletions(-) diff --git a/.gitignore b/.gitignore index bc7236ea18698..43eb89cacc0a5 100644 --- a/.gitignore +++ b/.gitignore @@ -196,5 +196,8 @@ _build/ *_hip* hip_compat.h +# version file generated by setuptools-scm +/vllm/_version.py + # Benchmark dataset benchmarks/*.json diff --git a/Dockerfile b/Dockerfile index 30e27620574a0..ec803764a128d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -79,15 +79,13 @@ ENV MAX_JOBS=${max_jobs} ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads -ARG buildkite_commit -ENV BUILDKITE_COMMIT=${buildkite_commit} - ARG USE_SCCACHE ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_S3_NO_CREDENTIALS=0 # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" = "1" ]; then \ echo "Installing sccache..." \ && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ @@ -107,6 +105,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" != "1" ]; then \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ fi diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 4d7289366296b..a9d97a3e0bde4 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -62,8 +62,10 @@ ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=bind,source=.git,target=.git \ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ - pip install dist/*.whl + pip install dist/*.whl && \ + rm -rf dist WORKDIR /workspace/ diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 647ed99a41e70..adae6db87ba87 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -6,9 +6,12 @@ FROM $BASE_IMAGE RUN echo "Base image is $BASE_IMAGE" # Install some basic utilities -RUN apt-get update \ - && apt-get install python3 python3-pip -y \ - && apt-get install -y ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python3-pip \ + ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### # When launching the container, mount the code directory to /app @@ -22,17 +25,17 @@ RUN python3 -m pip install sentencepiece transformers==4.36.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -COPY ./vllm /app/vllm/vllm -COPY 
./setup.py /app/vllm/setup.py -COPY ./requirements-common.txt /app/vllm/requirements-common.txt -COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt +COPY . /app/vllm RUN cd /app/vllm \ - && python3 -m pip install -U -r requirements-neuron.txt + && python3 -m pip install -U \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-neuron.txt ENV VLLM_TARGET_DEVICE neuron -RUN cd /app/vllm \ - && pip install -e . \ +RUN --mount=type=bind,source=.git,target=.git \ + cd /app/vllm \ + && pip install --no-build-isolation -v -e . \ && cd .. CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 96b9593a2bfa8..95714a3d17188 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -4,8 +4,9 @@ FROM ubuntu:22.04 AS dev RUN apt-get update -y && \ - apt-get install -y python3-pip git && \ - apt-get install -y ffmpeg libsm6 libxext6 libgl1 + apt-get install -y \ + git python3-pip \ + ffmpeg libsm6 libxext6 libgl1 WORKDIR /workspace # copy requirements diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 3313162bf28e1..1f374b01b9bc0 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -16,9 +16,15 @@ COPY ./ /workspace/vllm WORKDIR /workspace/vllm # These packages will be in rocketce eventually -RUN pip install -v cmake xformers torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing - -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + torch==2.3.1 \ + -r requirements-cpu.txt \ + xformers uvloop==0.20.0 + +RUN --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py install WORKDIR /workspace/ diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 33423fde4ff96..a12d5ba5fd8f5 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -51,13 +51,15 @@ RUN python3 -m pip install --upgrade pip # TODO: implement sccache support across components RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" # Install torch == 2.5.0 on ROCm -RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ +RUN --mount=type=cache,target=/root/.cache/pip \ + case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ *"rocm-6.1"*) \ python3 -m pip uninstall -y torch torchvision \ - && python3 -m pip install --no-cache-dir --pre \ + && python3 -m pip install --pre \ torch==2.5.0.dev20240726 \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ torchvision==0.20.0.dev20240726 \ - --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ + --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.1 ;; \ *) ;; esac ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer @@ -138,6 +140,7 @@ ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false RUN --mount=type=cache,target=${CCACHE_DIR} \ + --mount=type=bind,source=.git,target=.git \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -Ur requirements-rocm.txt \ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 04cd4d79f4045..d8f1a42c45177 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -5,16 +5,25 @@ FROM $BASE_IMAGE WORKDIR /workspace # Install some basic utilities -RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update 
&& apt-get install -y \ + git \ + ffmpeg libsm6 libxext6 libgl1 # Install the TPU and Pallas dependencies. -RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html # Build vLLM. COPY . /workspace/vllm ENV VLLM_TARGET_DEVICE="tpu" -RUN cd /workspace/vllm && python3 -m pip install -r requirements-tpu.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + cd /workspace/vllm && \ + python3 -m pip install \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-tpu.txt RUN cd /workspace/vllm && python3 setup.py develop CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu index 8f61e4c55260e..8471edd16e4bb 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -7,15 +7,20 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg -RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update -y && \ + apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 COPY ./ /workspace/vllm WORKDIR /workspace/vllm -RUN pip install -v -r requirements-xpu.txt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -v --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-xpu.txt -RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=xpu python3 setup.py install CMD ["/bin/bash"] diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 816e0a29ef28b..c8947beb34942 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -56,7 +56,7 @@ Build from source .. 
code-block:: console $ pip install --upgrade pip - $ pip install wheel packaging ninja "setuptools>=49.4.0" numpy + $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu - Third, build and install oneDNN library from source: diff --git a/pyproject.toml b/pyproject.toml index 14f0934499c46..4e1841484420a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ requires = [ "cmake>=3.26", "ninja", "packaging", - "setuptools >= 49.4.0", + "setuptools>=61", + "setuptools-scm>=8.0", "torch == 2.4.0", "wheel", "jinja2", @@ -19,6 +20,10 @@ exclude = [ "examples/fp8/quantizer/quantize.py" ] +[tool.ruff.lint.per-file-ignores] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + [tool.ruff.lint] select = [ # pycodestyle @@ -46,6 +51,9 @@ ignore = [ "UP032", ] +[tool.setuptools_scm] +version_file = "vllm/_version.py" + [tool.mypy] python_version = "3.8" diff --git a/requirements-build.txt b/requirements-build.txt index 3f08f5d67b6da..6144a56da8c47 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -2,7 +2,8 @@ cmake>=3.26 ninja packaging -setuptools>=49.4.0 +setuptools>=61 +setuptools-scm>=8 torch==2.4.0 wheel jinja2 diff --git a/setup.py b/setup.py index 60e31af0a8d39..85a2852136eaa 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ import re import subprocess import sys -import warnings from pathlib import Path from shutil import which from typing import Dict, List @@ -14,6 +13,7 @@ from packaging.version import Version, parse from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext +from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME @@ -28,34 +28,6 @@ def load_module_from_path(module_name, path): ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) - -def embed_commit_hash(): - try: - if "BUILDKITE_COMMIT" in os.environ: - # ci build - commit_id = os.environ["BUILDKITE_COMMIT"] - else: - commit_id = subprocess.check_output(["git", "rev-parse", "HEAD"], - encoding="utf-8").strip() - - commit_contents = f'__commit__ = "{commit_id}"\n' - - version_file = os.path.join(ROOT_DIR, "vllm", "commit_id.py") - with open(version_file, "w", encoding="utf-8") as f: - f.write(commit_contents) - - except subprocess.CalledProcessError as e: - warnings.warn(f"Failed to get commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - except Exception as e: - warnings.warn(f"Failed to embed commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - - -embed_commit_hash() - # cannot import envs directly because it depends on vllm, # which is not installed yet envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) @@ -381,21 +353,9 @@ def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) -def find_version(filepath: str) -> str: - """Extract version information from the given filepath. - - Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py - """ - with open(filepath) as fp: - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - fp.read(), re.M) - if version_match: - return version_match.group(1) - raise RuntimeError("Unable to find version string.") - - def get_vllm_version() -> str: - version = find_version(get_path("vllm", "version.py")) + version = get_version() + sep = "+" if "+" not in version else "." 
# dev versions might contain + if _no_device(): if envs.VLLM_TARGET_DEVICE == "empty": @@ -406,27 +366,27 @@ def get_vllm_version() -> str: cuda_version_str = cuda_version.replace(".", "")[:3] # skip this for source tarball, required for pypi if "sdist" not in sys.argv: - version += f"+cu{cuda_version_str}" + version += f"{sep}cu{cuda_version_str}" elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() if hipcc_version != MAIN_CUDA_VERSION: rocm_version_str = hipcc_version.replace(".", "")[:3] - version += f"+rocm{rocm_version_str}" + version += f"{sep}rocm{rocm_version_str}" elif _is_neuron(): # Get the Neuron version neuron_version = str(get_neuronxcc_version()) if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] - version += f"+neuron{neuron_version_str}" + version += f"{sep}neuron{neuron_version_str}" elif _is_openvino(): - version += "+openvino" + version += f"{sep}openvino" elif _is_tpu(): - version += "+tpu" + version += f"{sep}tpu" elif _is_cpu(): - version += "+cpu" + version += f"{sep}cpu" elif _is_xpu(): - version += "+xpu" + version += f"{sep}xpu" else: raise RuntimeError("Unknown runtime environment") diff --git a/tests/test_embedded_commit.py b/tests/test_embedded_commit.py index 17b01651e39af..ffeacf34b7baf 100644 --- a/tests/test_embedded_commit.py +++ b/tests/test_embedded_commit.py @@ -2,6 +2,7 @@ def test_embedded_commit_defined(): - assert vllm.__commit__ != "COMMIT_HASH_PLACEHOLDER" - # 7 characters is the length of a short commit hash - assert len(vllm.__commit__) >= 7 + assert hasattr(vllm, "__version__") + assert hasattr(vllm, "__version_tuple__") + assert vllm.__version__ != "dev" + assert vllm.__version_tuple__ != (0, 0, "dev") diff --git a/vllm/__init__.py b/vllm/__init__.py index 59af68fb493e5..8f477ea84756d 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -12,11 +12,11 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from .version import __commit__, __version__ +from .version import __version__, __version_tuple__ __all__ = [ - "__commit__", "__version__", + "__version_tuple__", "LLM", "ModelRegistry", "PromptType", diff --git a/vllm/version.py b/vllm/version.py index 0ddc7fb99ad45..66e189dcedf71 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -1,13 +1,11 @@ -import warnings - try: - import vllm.commit_id - - __commit__ = vllm.commit_id.__commit__ + from ._version import __version__, __version_tuple__ except Exception as e: + import warnings + warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) - __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.1.post2" + __version__ = "dev" + __version_tuple__ = (0, 0, __version__) From 86e9c8df29a954a7a2fc46e9985fecc2a2e15ae8 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 23 Sep 2024 13:46:26 -0400 Subject: [PATCH 02/50] [Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701) Co-authored-by: mgoin Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Co-authored-by: Tyler Michael Smith --- CMakeLists.txt | 1 + benchmarks/kernels/benchmark_machete.py | 74 ++++++-- benchmarks/kernels/requirements.txt | 1 + csrc/cutlass_extensions/torch_utils.hpp | 8 +- csrc/ops.h | 2 + csrc/permute_cols.cu | 88 +++++++++ csrc/quantization/machete/generate.py | 173 +++++++++++++----- .../machete/machete_mm_kernel.cuh | 3 +- .../machete/machete_mm_launcher.cuh | 2 +- .../machete/machete_prepack_launcher.cuh | 2 +- 
csrc/torch_bindings.cpp | 3 + tests/kernels/test_machete_gemm.py | 3 + tests/kernels/test_permute_cols.py | 15 ++ vllm/_custom_ops.py | 19 +- .../layers/quantization/awq_marlin.py | 9 +- .../schemes/compressed_tensors_wNa16.py | 114 ++++-------- .../layers/quantization/gptq_marlin.py | 133 +++++--------- .../quantization/kernels/MPLinearKernel.py | 83 +++++++++ .../layers/quantization/kernels/__init__.py | 72 ++++++++ .../layers/quantization/kernels/machete.py | 118 ++++++++++++ .../layers/quantization/kernels/marlin.py | 132 +++++++++++++ .../layers/quantization/utils/__init__.py | 3 + .../layers/quantization/utils/layer_utils.py | 33 ++++ .../quantization/utils/machete_utils.py | 30 +++ .../layers/quantization/utils/marlin_utils.py | 29 +-- .../layers/quantization/utils/quant_utils.py | 43 +++++ vllm/model_executor/parameter.py | 58 ++++++ 27 files changed, 1005 insertions(+), 246 deletions(-) create mode 100644 benchmarks/kernels/requirements.txt create mode 100644 csrc/permute_cols.cu create mode 100644 tests/kernels/test_permute_cols.py create mode 100644 vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/__init__.py create mode 100644 vllm/model_executor/layers/quantization/kernels/machete.py create mode 100644 vllm/model_executor/layers/quantization/kernels/marlin.py create mode 100644 vllm/model_executor/layers/quantization/utils/layer_utils.py create mode 100644 vllm/model_executor/layers/quantization/utils/machete_utils.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a04cd49c85a5..a05b53cba43f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" + "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index ca45cba6f8165..b70c4b94c97a1 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -4,8 +4,10 @@ import math import pickle as pkl import time -from typing import Callable, Iterable, List, Tuple +from itertools import product +from typing import Callable, Iterable, List, Optional, Tuple +import pandas as pd import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement @@ -84,6 +86,10 @@ def loop_over_weights( fn(a, w_ref, w_q, w_s) +_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None +_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None + + def bench(atype: torch.dtype, wtype: ScalarType, group_size: int, @@ -94,6 +100,8 @@ def bench(atype: torch.dtype, sub_label: str, benchmark_marlinv1: bool = True, sweep_schedules: bool = True) -> Iterable[TMeasurement]: + global _SWEEP_SCHEDULES_RESULTS + a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) sub_label += f", L={len(weights)}" @@ -163,6 +171,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: best_schedule = None schedules = ops.machete_supported_schedules(wtype) for schedule in reversed(schedules): + schedule_M = int(schedule.split("_")[0].split("x")[1]) + + # Prune known bad schedules + if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: + continue def run(a, _, w_q, w_s, schedule=schedule): ops.machete_gemm(a, @@ -175,6 +188,20 @@ 
def run(a, _, w_q, w_s, schedule=schedule): res = bench_fn(label, sub_label, "machete_best", lambda: loop_over_weights(a, weights_machete, run)) + results_row = { + "M": m, + "K": k, + "N": n, + "group_size": group_size, + "schedule": schedule, + "median": res.median, + } + if _SWEEP_SCHEDULES_RESULTS is None: + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( + columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.\ + loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + print(f" {res.median:5.5} ", schedule) if not best or res.median < best.median: best = res @@ -235,18 +262,22 @@ def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"square_bench-{args.dtype}") def run_range_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) - n = len(dim_sizes) - Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes - Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes - Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes - MKNs = list(zip(Ms, Ks, Ns)) + m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")] + m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")] + m_increment, k_increment, n_increment = \ + [int(x) for x in args.dim_increment.split(",")] + Ms = list(range(m_start, m_end + 1, m_increment)) + Ks = list(range(k_start, k_end + 1, k_increment)) + Ns = list(range(n_start, n_end + 1, n_increment)) + MKNs = list(product(Ms, Ks, Ns)) + data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"range_bench-{args.dtype}") @@ -333,6 +364,9 @@ def to_torch_dtype(dt): action="store_true", help="Run a sweep over all supported schedules", ) + parser.add_argument("--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv") subparsers = parser.add_subparsers(dest="cmd", required=True) square_parser = subparsers.add_parser("square_bench") @@ -342,12 +376,21 @@ def to_torch_dtype(dt): square_parser.set_defaults(func=run_square_bench) range_parser = subparsers.add_parser("range_bench") - range_parser.add_argument("--dim-start", type=int, required=True) - range_parser.add_argument("--dim-end", type=int, required=True) - range_parser.add_argument("--dim-increment", type=int, required=True) - range_parser.add_argument("--m-constant", type=int, default=None) - range_parser.add_argument("--n-constant", type=int, default=None) - range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.add_argument( + "--dim-start", + type=str, + required=True, + help="Start value for M,K,N as common separated list") + range_parser.add_argument( + "--dim-end", + type=str, + required=True, + help="End value (inclusive) for M,K,N as common separated list") + range_parser.add_argument( + "--dim-increment", + type=str, + required=True, + help="Increment value for M,K,N as common separated list") range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") @@ -369,4 +412,9 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() + + _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out args.func(args) + + if _SWEEP_SCHEDULES_RESULTS is not None: + _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV) diff --git a/benchmarks/kernels/requirements.txt b/benchmarks/kernels/requirements.txt 
new file mode 100644 index 0000000000000..1411a4a0b5ab8 --- /dev/null +++ b/benchmarks/kernels/requirements.txt @@ -0,0 +1 @@ +pandas \ No newline at end of file diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index 1618a340ce10e..2c78572521eec 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, name, ".stride(", idx, ") to be ", StrideEle::value); return StrideEle{}; } else { - return tensor.stride(idx); + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } } } else { // Extra strides are assumed to be 0 or 1 diff --git a/csrc/ops.h b/csrc/ops.h index 15e9ebe87408a..7ad0abd46c82a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B, }; // namespace machete +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); + torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, diff --git a/csrc/permute_cols.cu b/csrc/permute_cols.cu new file mode 100644 index 0000000000000..f51fa73298cc1 --- /dev/null +++ b/csrc/permute_cols.cu @@ -0,0 +1,88 @@ +#include + +#include +#include + +#include + +static constexpr int default_threads = 256; +static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +// Currently only supports 16bit types (since we permute half types) +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = std::max(finish_row - start_row, 0); + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +// More efficient version of A[..., perm] +// taken from gptq_marlin.cu +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto dev = A.get_device(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, + "Currently only 16bit types are supported"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(A.size(-1) % 
8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = A.view({-1, A.size(-1)}); + + torch::Tensor D = torch::empty_like(A); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int block_rows = div_ceil(A_2d.size(0), sms); + permute_cols_kernel<<>>( + reinterpret_cast(A_2d.const_data_ptr()), + perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), + A_2d.size(0), A_2d.size(1), block_rows); + return D; +} \ No newline at end of file diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 09a98a5dd1fd6..8ed81ea727aa3 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -157,7 +157,7 @@ TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative -@dataclass +@dataclass(frozen=True) class ScheduleConfig: tile_shape_mn: Tuple[int, int] cluster_shape_mnk: Tuple[int, int, int] @@ -328,56 +328,137 @@ def generate(): # about how this works SCRIPT_DIR = os.path.dirname(__file__) - schedules = [ - ScheduleConfig( - tile_shape_mn=tile_shape_mn, - cluster_shape_mnk=cluster_shape_mnk, - kernel_schedule=kernel_schedule, - epilogue_schedule=epilogue_schedule, - tile_scheduler=tile_scheduler, - ) for tile_shape_mn, cluster_shape_mnk in ( - ((128, 16), (1, 1, 1)), - ((128, 32), (1, 1, 1)), - ((128, 64), (1, 1, 1)), - ((128, 128), (1, 1, 1)), - ) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, ) - for tile_scheduler in (TileSchedulerType.StreamK, ) - ] + schedule_common_params = dict( + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + ) # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s default_heuristic = [ - ("M > 64", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - ("M > 32", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - ("M > 16", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - (None, - ScheduleConfig(tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK)) + #### M = 257+ + ( + "M > 256 && K <= 16384 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 256", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 129-256 + ( + "M > 128 && K <= 4096 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128 && K <= 8192 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 65-128 + ( + "M > 64 && K <= 4069 && N <= 4069", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K 
<= 4069 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K >= 8192 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 33-64 + ( + "M > 32 && K <= 6144 && N <= 6144", + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32 && K >= 16384 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 17-32 + ( + "M > 16 && K <= 12288 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 16", + ScheduleConfig( + tile_shape_mn=(256, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 1-16 + ( + "N >= 26624", + ScheduleConfig( + tile_shape_mn=(256, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + None, + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), ] + schedules = list(set([x[1] for x in default_heuristic])) + impl_configs = [] GPTQ_kernel_type_configs = list( diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 046e6e5a53652..4d41b8d291484 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -152,7 +152,8 @@ struct MacheteKernelTemplate { int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); - int const group_size = maybe_group_size.value_or(K); + int const group_size = + maybe_group_size == -1 ? 
K : maybe_group_size.value_or(K); int const scale_k = (K + group_size - 1) / group_size; TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index e2604d4bed3e2..60a4ed60535b7 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) { auto arguments = MacheteKernel::create_arguments( stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), - args.group_size.value_or(K)); + args.group_size); TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index 686dd68bd52bb..df78312997fb0 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) { // clang-format on // Allocate output - torch::Tensor D = torch::empty_like(B); + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); prepack_B(stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 045203c3de8a8..4b374af5ae24e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "-> Tensor"); ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); + ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); + ops.impl("permute_cols", torch::kCUDA, &permute_cols); + // gptq_marlin Optimized Quantized GEMM for GPTQ. 
ops.def( "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index 0a90882223077..0dfa79e9af8ec 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -31,6 +31,8 @@ (257, 4224, 4160), (257, 4096, 4096), (64, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), ] ACT_TYPES = [torch.float16, torch.bfloat16] @@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype, output_ref = torch.matmul(a, w_ref) for schedule in ops.machete_supported_schedules(wtype): + print(f"Testing schedule {schedule}") output = ops.machete_gemm( a, b_q=w_q_machete, diff --git a/tests/kernels/test_permute_cols.py b/tests/kernels/test_permute_cols.py new file mode 100644 index 0000000000000..14ad7a22cf7cf --- /dev/null +++ b/tests/kernels/test_permute_cols.py @@ -0,0 +1,15 @@ +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm._custom_ops import permute_cols + + +@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +def test_permute_cols(shape, dtype): + x = torch.randn(shape, dtype=dtype).cuda() + perm = torch.randperm(x.shape[1]).to(torch.int).cuda() + opcheck(torch.ops._C.permute_cols, (x, perm)) + y = permute_cols(x, perm) + torch.testing.assert_close(y, x[:, perm]) \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 678700055c992..a71bafc974adf 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -438,7 +438,8 @@ def machete_gemm_fake( @torch.library.register_fake("_C::machete_prepack_B") def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: - return torch.empty_like(b_q_weight) + return torch.empty_like(b_q_weight, + memory_format=torch.contiguous_format) @torch.library.register_fake("_C::causal_conv1d_fwd") def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, @@ -625,6 +626,22 @@ def machete_prepack_B(b_q_weight: torch.Tensor, return torch.ops._C.machete_prepack_B(b_q_weight, b_type) +# TODO: has to be a better way to do this +try: + torch.ops._C.permute_cols # noqa B018 + + @torch.library.register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) +except Exception: + pass + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eed01953fb4af..fe33b7341fd38 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -7,10 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from 
vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -231,7 +232,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) + replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. marlin_scales = marlin_permute_scales( @@ -239,7 +240,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. marlin_zp = awq_to_marlin_zero_points( @@ -247,7 +248,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.num_groups, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qzeros", marlin_zp) + replace_parameter(layer, "qzeros", marlin_zp) # Not-used layer.g_idx = marlin_make_empty_g_idx(device) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 3cade3d3fbcd0..cb65557be8f90 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,17 +1,16 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Set import torch -from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ActivationOrdering) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + marlin_repeat_scales_on_all_ranks) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -19,6 +18,8 @@ RowvLLMParameter) from vllm.scalar_type import scalar_types +logger = init_logger(__name__) + __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, @@ -28,6 +29,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() def __init__(self, strategy: str, @@ -52,35 +54,43 @@ def __init__(self, self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] - # Verify supported on platform. 
- verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): output_size_per_partition = sum(output_partition_sizes) + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsWNA16", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + # If group_size is -1, we are in channelwise case. group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) partition_scales = not marlin_repeat_scales_on_all_ranks( self.has_g_idx, self.group_size, row_parallel) - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - scales_and_zp_size = input_size // group_size if partition_scales: @@ -137,69 +147,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, weight_loader=weight_loader) layer.register_parameter("weight_g_idx", weight_g_idx) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.group_size = group_size + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. + # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Handle sorting for activation reordering if needed. - if self.has_g_idx: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "weight_g_idx", g_idx) - else: - layer.weight_g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - # Update for kernel - layer.weight_packed = torch.nn.Parameter( - layer.weight_packed.t().contiguous(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) - - # Repack weights from compressed-tensors format to marlin format. 
- marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) - - # Permute scales from compressed-tensors format to marlin format. - # scale is required on all partitions if activation reordering - marlin_scales = marlin_permute_scales( - layer.weight_scale, - size_k=(layer.input_size - if self.has_g_idx else layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) + self.kernel.process_weights_after_loading(layer) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - - return apply_gptq_marlin_linear( - input=x, - weight=layer.weight_packed, - weight_scale=layer.weight_scale, - weight_zp=layer.weight_zp, - g_idx=layer.weight_g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=True, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 5a1b2d701ab0d..3d3ce711e58b0 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,7 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import torch -from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -11,12 +10,12 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + check_marlin_supported, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): quant_config: The GPTQ Marlin quantization config. 
""" + _kernel_backends_being_used: Set[str] = set() + def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config @@ -176,25 +177,34 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - - del output_size output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQMarlinLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size, - ) - # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, self.quant_config.group_size, @@ -275,57 +285,15 @@ def create_weights( layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, - is_row_parallel) - - # Checkpoints are serialized in AutoGPTQ format, which is different from the - # marlin format. This function is called after the weights are loaded. - # Here, we handle the repacking, including the activation reordering case. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.qweight.device - # required by torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) - layer.scales = Parameter(layer.scales.data, requires_grad=False) + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") - # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Handle sorting for activation reordering if needed. - if self.quant_config.desc_act: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "g_idx", g_idx) - else: - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.zp = marlin_make_empty_g_idx(device) - - # Repack weights from autogptq format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits, - ) - replace_tensor(layer, "qweight", marlin_qweight) - - # Permute scales from autogptq format to marlin format. 
- marlin_scales = marlin_permute_scales( - layer.scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size, - ) - replace_tensor(layer, "scales", marlin_scales) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -333,20 +301,7 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return apply_gptq_marlin_linear( - input=x, - weight=layer.qweight, - weight_scale=layer.scales, - weight_zp=layer.zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_config.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=layer.is_k_full, - bias=bias, - ) + return self.kernel.apply_weights(layer, x, bias) class GPTQMarlinMoEMethod(FusedMoEMethodBase): @@ -506,12 +461,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_g_idx_sort_indices[e]] w2_sorted_g_idx[e] = layer.w2_g_idx[e][ w2_g_idx_sort_indices[e]] - replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) - replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) - replace_tensor(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_tensor(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) else: # Reset g_idx related tensors num_experts = layer.w13_g_idx.shape[0] @@ -544,7 +499,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_qweight.shape[2], self.quant_config.quant_type.size_bits, ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, @@ -552,7 +507,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_qweight.shape[2], self.quant_config.quant_type.size_bits, ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, @@ -560,14 +515,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w13_scales", marlin_w13_scales) + replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) + replace_parameter(layer, "w2_scales", marlin_w2_scales) def apply( self, diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py new file mode 100644 index 0000000000000..fe50c4930d043 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,83 @@ +from abc import ABC, 
abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000000000..47591c2aa644e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -0,0 +1,72 @@ +import os +from typing import List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. 
+ compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py new file mode 100644 index 0000000000000..fa39cb511528e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -0,0 +1,118 @@ +from functools import partial +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_weights_into_int32, unpack_weights_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. 
(Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_weights_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_weights_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + self.config.weight_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/marlin.py new file mode 100644 index 0000000000000..5b4bba76ee0ca --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/marlin.py @@ -0,0 +1,132 @@ +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def 
get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1], + c.full_weight_shape[1], + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return 
apply_gptq_marlin_linear(
+            input=x,
+            weight=w_q,
+            weight_scale=w_s,
+            weight_zp=w_zp,  # type: ignore
+            g_idx=w_gidx,  # type: ignore
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=self.workspace,
+            wtype=c.weight_type,
+            input_size_per_partition=c.partition_weight_shape[0],
+            output_size_per_partition=c.partition_weight_shape[1],
+            is_k_full=self.is_k_full,
+            bias=bias)
diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py
index e69de29bb2d1d..e60f0c79ac1f7 100644
--- a/vllm/model_executor/layers/quantization/utils/__init__.py
+++ b/vllm/model_executor/layers/quantization/utils/__init__.py
@@ -0,0 +1,3 @@
+from .layer_utils import replace_parameter, update_tensor_inplace
+
+__all__ = ['update_tensor_inplace', 'replace_parameter']
diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py
new file mode 100644
index 0000000000000..c38bd8955f457
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py
@@ -0,0 +1,33 @@
+from typing import Union
+
+import torch
+
+
+def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
+    assert dst.dtype == src.dtype, "Tensors must have the same dtype"
+
+    # update tensor shape and stride
+    dst.as_strided_(src.shape, src.stride())
+
+    # If the tensors do not share the same underlying storage, copy the data
+    if dst.data_ptr() != src.data_ptr():
+        dst.copy_(src)
+        del src
+
+
+# Newly generated tensors need to replace existing tensors that are
+# already registered as parameters by vLLM (and won't be freed)
+def replace_parameter(mod: torch.nn.Module, name: str,
+                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
+
+    old = getattr(mod, name)
+    if old.dtype == new.dtype and \
+        old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
+        # Update in-place to avoid re-registering the parameter; this is
+        # faster when the underlying storage can be reused
+        update_tensor_inplace(old, new)
+    else:
+        # Fallback: re-register the parameter (must not require grad)
+        if not isinstance(new, torch.nn.Parameter):
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name, torch.nn.Parameter(new.data, requires_grad=False))
diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py
new file mode 100644
index 0000000000000..18e1332050cdd
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py
@@ -0,0 +1,30 @@
+from typing import List, Optional, Tuple
+
+import torch
+
+from vllm.scalar_type import ScalarType, scalar_types
+
+MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
+MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
+
+
+def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
+    if zero_points:
+        return [scalar_types.uint4, scalar_types.uint8]
+    else:
+        return [scalar_types.uint4b8, scalar_types.uint8b128]
+
+
+def query_machete_supported_act_types(zero_points: bool) -> List[torch.dtype]:
+    return [torch.float16, torch.bfloat16]
+
+
+def check_machete_supports_shape(in_features: int, out_features: int) \
+    -> Tuple[bool, Optional[str]]:
+    if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
+        return False, "Input features size must be divisible by "\
+            f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
+    if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
+        return False, "Output features size must be divisible by "\
+            f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}"
+    return True, None
diff --git 
a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index fea94cf7322ad..53762965732ce 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -120,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int, "with --quantization gptq.") +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // @@ -148,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: requires_grad=False) +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + def marlin_sort_g_idx( g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) @@ -240,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index bdfda31de852b..833d00073564e 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,6 +20,49 @@ } +def pack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., 
i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 9ffb339ffeab3..7a6d7c90f34d5 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -328,6 +328,64 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +def permute_param_layout_(param: BasevLLMParameter, input_dim: int, + output_dim: int, **kwargs) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters when either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size From 9b0e3ec970f6a19427be358848a2ed663fd735e1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 24 Sep 2024 02:57:42 +0800 Subject: [PATCH 03/50] [Kernel][LoRA] Add assertion for punica sgmv kernels (#7585) --- tests/lora/test_punica_sizes.py | 5 ++++ tests/lora/test_punica_variation.py | 5 ++++ vllm/lora/ops/bgmv_expand.py | 2 +- vllm/lora/ops/bgmv_expand_slice.py | 2 +- vllm/lora/ops/sgmv_expand.py | 16 +++++++----- vllm/lora/ops/sgmv_expand_slice.py | 18 ++++++++------ vllm/lora/ops/sgmv_shrink.py | 16 +++++++----- vllm/lora/punica.py | 38 ++++++++++++++++------------- 8 files changed, 64 
insertions(+), 38 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 314d6215cbd9c..41c37a4813c68 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -169,6 +169,7 @@ def test_punica_sgmv( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -183,6 +184,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, scaling, ) else: @@ -195,6 +197,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, add_inputs=True, ) ref_torch_groupgemm( @@ -347,6 +350,7 @@ def test_punica_expand_nslices( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -364,6 +368,7 @@ def test_punica_expand_nslices( lora_indices_tensor, batches, max_seq_length, + token_nums, slice_offset, hidden_size, add_inputs=True, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 28a395af19e6d..185da6399a06a 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -84,6 +84,7 @@ def test_punica_sgmv( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -98,6 +99,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, scaling, ) else: @@ -110,6 +112,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, add_inputs=True, ) ref_torch_groupgemm( @@ -262,6 +265,7 @@ def test_punica_expand_nslices( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -279,6 +283,7 @@ def test_punica_expand_nslices( lora_indices_tensor, batches, max_seq_length, + token_nums, slice_offset, hidden_size, add_inputs=True, diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 619408b9315cf..6a32387a6f36c 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -100,7 +100,7 @@ def _bgmv_expand( corresponding to each batch, An index of -1 means no lora should be applied. batches (int): batch size - add_inputs (bool, optional): Defaults to False. adds the final lora + add_inputs (bool, optional): Defaults to False, adds the final lora results to the output. """ assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index c16db233891a5..73628fd20d327 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -104,7 +104,7 @@ def _bgmv_expand_slice( lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch, An index of -1 means no lora should be applied. - slice_offst (int): output_tensor's offst + slice_offset (int): output_tensor's offset slice_size (int): current output_tensor's size batches (int): batch size add_inputs (bool, optional): Defaults to False. 
diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index c71332d8bdfb2..adb3ab5b46b87 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -106,6 +106,7 @@ def _sgmv_expand( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, add_inputs: bool = False, ) -> None: """ @@ -115,17 +116,19 @@ def _sgmv_expand( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - add_inputs (bool, optional): Defaults to False. adds the final lora + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + add_inputs (bool, optional): Defaults to False, adds the final lora results to the output. """ @@ -134,6 +137,7 @@ def _sgmv_expand( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_b_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index b4ae9a2acbb5c..efa234520ab87 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -112,6 +112,7 @@ def _sgmv_expand_slice( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, slice_offset: int, slice_size: int, add_inputs: bool = False, @@ -124,20 +125,22 @@ def _sgmv_expand_slice( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences + max_seq_length (int): The max sequence lengths of the sequences in the batch - slice_offst (int): output_tensor's offst + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + slice_offset (int): output_tensor's offset slice_size (int): current output_tensor's size - add_inputs (bool, optional): Defaults to False. adds the final lora - results to the output.. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. 
""" assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] @@ -145,6 +148,7 @@ def _sgmv_expand_slice( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_b_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index c0791c260e915..c003f3dc0ce9e 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -110,6 +110,7 @@ def _sgmv_shrink( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, scaling: float, ) -> None: """ @@ -120,17 +121,19 @@ def _sgmv_shrink( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - scaling (float): Scaling factor. + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + scaling (float): Scaling factor. """ assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] @@ -138,6 +141,7 @@ def _sgmv_shrink( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_a_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 6d5c834299961..5033ce4126929 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -27,7 +27,7 @@ def compute_meta( token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function @@ -43,7 +43,7 @@ def compute_meta( b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) max_length = seq_length_tensor.max().item() - + token_nums = seq_length_tensor.sum().item() batch_size = lora_indices_tensor.size(0) no_lora = False # -1 means no lora should be applied. Use `no_lora` to determine whether @@ -52,7 +52,7 @@ def compute_meta( if batch_size == 1 and lora_indices_tensor == -1: no_lora = True return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) + batch_size, max_length, token_nums, no_lora) # TODO see if this can be vectorized @@ -178,7 +178,7 @@ def convert_mapping( class PunicaWrapper: """ PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + kernel. 
The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica kernel. """ @@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, dtype=torch.long, device=device) self.max_length: int = 0 + self.token_nums: int = 0 self.batch_size: int = -1 self.is_prefill = False self.no_lora = False @@ -276,13 +277,13 @@ def _update_base_metadata( long_lora_offsets_tensor) else: self._long_lora_indices.zero_() - self.indices_len[:] = indices_len def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) = compute_meta(token_lora_tensor) + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( b_seq_start_tensor) @@ -291,25 +292,28 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: lora_indices_tensor) self.batch_size = batch_size self.max_length = max_length + self.token_nums = token_nums self.no_lora = no_lora @property def prefill_metadata( - self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """ This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. - 1. seq_start_locs: Tensor of sequence start positions - 2. seq_lengths: Tensor of sequence lengths + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. 3. lora_indices_per_batch: Tensor of lora indices, and an index of -1 means no lora should be applied. - 4. batch_size: batch size after clustering identical lora indices - 5. max_length: The maximum sequence length in the batch + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. """ return (self._seq_start_locs[:self.batch_size], self._seq_lengths[:self.batch_size], self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length) + self.batch_size, self.max_length, self.token_nums) @property def token_lora_indices(self) -> torch.Tensor: @@ -324,7 +328,7 @@ def token_lora_indices(self) -> torch.Tensor: def sampler_indices(self) -> torch.Tensor: """ This property is used to access the lora indices specifically for - LogitsProcessorWithLoRA + LogitsProcessorWithLoRA. """ sampler_indices_len = self.indices_len[1] return self._sampler_indices[:sampler_indices_len] @@ -332,7 +336,7 @@ def sampler_indices(self) -> torch.Tensor: @property def sampler_indices_padded(self) -> torch.Tensor: """ - This property provides access to padded sampler indices + This property provides access to padded sampler indices. """ indices_padded_len = self.indices_len[2] return self._sampler_indices_padded[:indices_padded_len] @@ -341,7 +345,7 @@ def sampler_indices_padded(self) -> torch.Tensor: def embeddings_indices(self) -> torch.Tensor: """ This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA + specifically for VocabParallelEmbeddingWithLoRA. 
""" embeddings_indices_len = self.indices_len[3] return self._embeddings_indices[:, :embeddings_indices_len] @@ -350,7 +354,7 @@ def embeddings_indices(self) -> torch.Tensor: def long_lora_indices(self) -> torch.Tensor: """ This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora + lora, specifically for LinearScalingRotaryEmbeddingWithLora. """ long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] @@ -524,7 +528,7 @@ def add_lora(self, scale (float): Scaling factor. y_offset (Optional[int], optional): Offset to apply to the starting column of y. - y_slice_size (Optional[int], optional): Size of the y column slice.. + y_slice_size (Optional[int], optional): Size of the y column slice. buffer (Optional[torch.Tensor], optional): Defaults to None. """ y_org = y From b05f5c9238c3e0c3a98080b4ffc90acfa33f9e1f Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 23 Sep 2024 15:15:41 -0400 Subject: [PATCH 04/50] [Core] Allow IPv6 in VLLM_HOST_IP with zmq (#8575) Signed-off-by: Russell Bryant --- vllm/distributed/device_communicators/shm_broadcast.py | 7 ++++++- vllm/utils.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index b507cd2e1cddb..7d526b25ed193 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -9,11 +9,12 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import get_ip, get_open_port +from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -214,6 +215,8 @@ def __init__( self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) socket_addr = f"tcp://*:{remote_subscribe_port}" self.remote_socket.bind(socket_addr) @@ -274,6 +277,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if is_valid_ipv6_address(handle.connect_ip): + self.remote_socket.setsockopt(IPV6, 1) socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) diff --git a/vllm/utils.py b/vllm/utils.py index db2ef146e38ea..b73e3b9bbf68e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import enum import gc import inspect +import ipaddress import os import random import socket @@ -533,6 +534,14 @@ def get_ip() -> str: return "0.0.0.0" +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + def get_distributed_init_method(ip: str, port: int) -> str: # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848 From 5f7bb584272ee15147a411b887e7ababd6b9b9d0 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Tue, 24 Sep 2024 03:32:27 +0800 Subject: 
[PATCH 05/50] Fix typical acceptance sampler with correct recovered token ids (#8562) --- .../test_typical_acceptance_sampler.py | 17 ++++++----- .../layers/typical_acceptance_sampler.py | 28 ++++++------------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 1eba98cefd04a..4ddad66dce1fb 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str): # Next only keep the first 2 draft tokens same as the zero temperature # tokens. For the remaining 3 choose some other tokens. In the # response we will expect the first 2 tokens to be the same as the - # draft tokens and the rest as -1 + # draft tokens and the recovered token and rest as -1 draft_token_ids_to_replace = get_draft_token_ids( batch_size, k, vocab_size, zero_temperature_token_ids) draft_token_ids = torch.cat( @@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str): assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) + assert torch.all( + output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2]) assert torch.all(output_token_ids[:, -3:] == -1) @@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str): @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_replacement_token_ids(seed: int, device: str): +def test_get_recovered_token_ids(seed: int, device: str): """ Test the TypicalAcceptanceSampler's method for generating replacement token IDs. - This test verifies that the `_replacement_token_ids` method of the + This test verifies that the `_get_recovered_token_ids` method of the TypicalAcceptanceSampler correctly identifies the token IDs to be used - as replacements based on the target probability distribution. + as recovered token IDs based on the target probability distribution. Specifically, it ensures that the method correctly identifies the tokens with the highest probability for each sequence in the batch. 
""" @@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str): typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - expected_replacement_tokens = -torch.ones( - (batch_size, k), dtype=torch.long) - expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :], - dim=1) + expected_replacement_tokens = torch.argmax(target_probs, dim=-1) actual_replacement_tokens = ( - typical_acceptance_sampler._replacement_token_ids(target_probs)) + typical_acceptance_sampler._get_recovered_token_ids(target_probs)) assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 8c03e46927752..584cf971d9c05 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -80,7 +80,7 @@ def forward( target_probs = target_with_bonus_probs[:, :-1] accepted = self._evaluate_accepted_tokens(target_probs, draft_token_ids) - recovered_token_ids = self._replacement_token_ids(target_probs) + recovered_token_ids = self._get_recovered_token_ids(target_probs) output_token_ids = self._create_output(accepted, recovered_token_ids, draft_token_ids, bonus_token_ids) @@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): accepted_mask = candidates_prob > threshold return accepted_mask - def _replacement_token_ids(self, target_probs): + def _get_recovered_token_ids(self, target_probs): """ - Generate one replacement token ID for each sequence based on target - probabilities. The replacement token is used as the fallback option - if typical acceptance sampling does not accept any draft tokens for - that particular sequence. - - This method computes the token IDs to be replaced by selecting the - token with the highest probability for each sequence in the first - position. The rest of the output is filled with -1. + The recovered token ids will fill the first unmatched token + by the target token. Parameters ---------- @@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs): Returns ------- torch.Tensor - A tensor of shape (batch_size, k) with the replacement - token IDs. Only the first column is set, and the rest of the - columns are filled with -1. + A tensor of shape (batch_size, k) with the recovered token + ids which are selected from target probs. 
""" - max_indices = torch.argmax(target_probs[:, 0, :], dim=1) - output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype, - device=target_probs.device) - output[:, 0] = max_indices - return output + max_indices = torch.argmax(target_probs, dim=-1) + + return max_indices From 1a2aef3e59f5429299618bd3b242833cb377f554 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Mon, 23 Sep 2024 18:38:04 -0400 Subject: [PATCH 06/50] Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (#8335) --- tests/entrypoints/openai/test_accuracy.py | 6 +- vllm/config.py | 2 + vllm/engine/arg_utils.py | 6 ++ vllm/engine/llm_engine.py | 37 ++++++--- vllm/engine/multiprocessing/engine.py | 9 ++- vllm/outputs.py | 96 +++++++++++++++++------ vllm/sequence.py | 28 ++++--- 7 files changed, 142 insertions(+), 42 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 2ad8460023c25..63beaaba29a80 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -19,7 +19,11 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] -MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] +MORE_ARGS_LIST = [ + ["--enable-chunked-prefill"], # Chunked + ["--num-scheduler-steps", "8"], # MS + ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream +] @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) diff --git a/vllm/config.py b/vllm/config.py index 960a8d3928584..8c65d99c44651 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -960,6 +960,7 @@ def __init__(self, is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, + multi_step_stream_outputs: bool = False, send_delta_data: bool = False) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: @@ -1000,6 +1001,7 @@ def __init__(self, self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps + self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca6034ddbe5c5..0d4559e377427 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -145,6 +145,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 + multi_step_stream_outputs: bool = False ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -595,6 +596,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=('Maximum number of forward steps per ' 'scheduler call.')) + parser.add_argument( + '--multi-step-stream-outputs', + action='store_true', + help='If True, then multi-step will stream outputs for every step') parser.add_argument( '--scheduler-delay-factor', type=float, @@ -999,6 +1004,7 @@ def create_engine_config(self) -> EngineConfig: is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 
80dde804addac..1e77a01bfa9d9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -95,7 +95,7 @@ class OutputData(NamedTuple): class SchedulerContext: - def __init__(self): + def __init__(self, multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -103,6 +103,8 @@ def __init__(self): List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, @@ -219,6 +221,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -234,8 +237,9 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s)", + "num_scheduler_steps=%d, multi_step_stream_outputs=%s, " + "enable_prefix_caching=%s, use_async_output_proc=%s, " + "use_cached_outputs=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -266,8 +270,10 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, + scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + use_cached_outputs, model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. @@ -287,6 +293,7 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -379,7 +386,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ] self.scheduler_contexts = [ - SchedulerContext() + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. 
+ multi_step_stream_outputs) for _ in range(self.parallel_config.pipeline_parallel_size) ] @@ -998,7 +1006,8 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1019,8 +1028,8 @@ def _process_model_outputs(self, for scheduler in self.scheduler: scheduler.free_finished_seq_groups() - # For multi-step, do not create outputs each iteration - if not is_last_step: + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: # Immediately process request outputs here (if callback is given) if (finished_now and self.process_request_outputs_callback is not None): @@ -1037,17 +1046,27 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + for seq_group in scheduler_outputs.ignored_seq_groups: params = seq_group.sampling_params if params is not None and params.output_kind == ( RequestOutputKind.DELTA) and not seq_group.is_finished(): continue - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 788c1573ae255..3b0f617629d63 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -66,7 +66,14 @@ def __init__(self, *args, log_requests: bool = True, **kwargs) -> None: - self.engine = LLMEngine(*args, **kwargs) + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. 
+ use_cached_outputs = True + + self.engine = LLMEngine(*args, + **kwargs, + use_cached_outputs=use_cached_outputs) self.log_requests = log_requests self.use_async_sockets = use_async_sockets diff --git a/vllm/outputs.py b/vllm/outputs.py index 85ea9196b25df..44cde6b561d85 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -114,17 +114,28 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, - seq_group: SequenceGroup) -> Optional["RequestOutput"]: + def from_seq_group(cls, seq_group: SequenceGroup, + use_cache: bool) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: raise ValueError( "Sampling parameters are missing for a CompletionRequest.") + finished = seq_group.is_finished() if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( not finished): return None + # Init cache (if needed) + if use_cache and seq_group.cached_request_output is None: + seq_group.cached_request_output = RequestOutput( # type: ignore + request_id="", + prompt=None, + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[], + finished=False) + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs @@ -149,29 +160,66 @@ def from_seq_group(cls, outputs = [] include_prompt = True - for seq in top_n_seqs: + for i, seq in enumerate(top_n_seqs): output_text = seq.get_output_text_to_return( text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) + num_output_tokens = 1 if isinstance(output_token_ids, + int) else len(output_token_ids) + output_logprobs = seq.output_logprobs if include_logprobs else None if delta: # Slice logprobs delta if applicable if output_logprobs: - output_logprobs = output_logprobs[-len(output_token_ids):] + output_logprobs = output_logprobs[-num_output_tokens:] # Don't include prompt if this is after the first output # containing decode token ids - if include_prompt and seq.get_output_len() > len( - output_token_ids): + if include_prompt and seq.get_output_len() > num_output_tokens: include_prompt = False - outputs.append( - CompletionOutput( - seqs.index(seq), output_text, output_token_ids, + if use_cache: + # Get cached output object + cached_outputs = seq_group.cached_request_output.outputs # type: ignore + if i >= len(cached_outputs): + cached_outputs.append( + CompletionOutput(index=i, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + stop_reason=None)) + output = cached_outputs[i] + + # Init cached output object + assert output.index == i + output.text = output_text + + if isinstance(output_token_ids, int): + output.token_ids.clear() + output.token_ids.append(output_token_ids) + else: + output.token_ids = output_token_ids + + output.cumulative_logprob = seq.get_cumulative_logprob() \ + if include_logprobs else None + output.logprobs = output_logprobs + output.finish_reason = SequenceStatus.get_finished_reason( + seq.status) + output.stop_reason = seq.stop_reason + + else: + output = CompletionOutput( + seqs.index(seq), output_text, [output_token_ids] + if isinstance(output_token_ids, int) else output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, output_logprobs, SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason)) + seq.stop_reason) + + outputs.append(output) # Every sequence in the sequence group should have the same prompt. 
if include_prompt: @@ -188,16 +236,20 @@ def from_seq_group(cls, prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) - return cls(seq_group.request_id, - prompt, - prompt_token_ids, - prompt_logprobs, - outputs, - finished, - seq_group.metrics, - lora_request=seq_group.lora_request, - encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + + init_args = (seq_group.request_id, prompt, prompt_token_ids, + prompt_logprobs, outputs, finished, seq_group.metrics, + seq_group.lora_request, encoder_prompt, + encoder_prompt_token_ids) + + if use_cache: + request_output = seq_group.cached_request_output + request_output.__init__(*init_args) # type: ignore + + else: + request_output = cls(*init_args) + + return request_output def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -261,10 +313,10 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create(seq_group): + def create(seq_group: SequenceGroup, use_cache: bool = False): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: - return RequestOutput.from_seq_group(seq_group) + return RequestOutput.from_seq_group(seq_group, use_cache) diff --git a/vllm/sequence.py b/vllm/sequence.py index d8e54ff1fc708..79e8a1f6244d7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -436,7 +436,7 @@ def __init__( self.stop_reason: Union[int, str, None] = None # These are used to keep track of delta outputs - self._last_token_ids_offset: int = 0 + self._last_output_token_ids_offset: int = 0 self._last_output_text_offset: int = 0 # Used for incremental detokenization @@ -499,18 +499,26 @@ def get_output_text_to_return(self, buffer_length: int, return self.output_text[last_offset:length] return "" - def get_output_token_ids_to_return(self, - delta: bool) -> GenericSequence[int]: + def get_output_token_ids_to_return( + self, delta: bool) -> Union[GenericSequence[int], int]: """If delta is True, only new tokens since the last call to this method are returned""" if not delta: return self.get_output_token_ids() - length = self.get_output_len() - last_offset = self._last_token_ids_offset - if last_offset < length: - self._last_token_ids_offset = length - return self.data._output_token_ids[last_offset:] - return () + + output_len = self.get_output_len() + + # Get the number of new tokens + num_new_tokens = output_len - self._last_output_token_ids_offset + self._last_output_token_ids_offset = output_len + + # Return new tokens + if num_new_tokens == 1: + # Optimization for single decode token case + # (which is what we have most of the time) + return self.data._cached_all_token_ids[-1] + + return self.data._cached_all_token_ids[-num_new_tokens:] def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -671,6 +679,8 @@ def __init__( self.encoder_seq = encoder_seq self.trace_headers = trace_headers + self.cached_request_output = None + @property def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. 
From 530821d00cb2beeb8dc62f74f0e4e0003868dc93 Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Mon, 23 Sep 2024 21:52:39 -0400 Subject: [PATCH 07/50] [Hardware][AMD] ROCm6.2 upgrade (#8674) --- Dockerfile.rocm | 56 ++++++---------- .../getting_started/amd-installation.rst | 65 ++++++++++++------- 2 files changed, 61 insertions(+), 60 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a12d5ba5fd8f5..9aa3a974e7046 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,5 +1,5 @@ -# Default ROCm 6.1 base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" +# Default ROCm 6.2 base image +ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0" # Default ROCm ARCHes to build vLLM for. ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" @@ -7,18 +7,12 @@ ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" # Whether to install CK-based flash-attention # If 0, will not install flash-attention ARG BUILD_FA="1" -# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL` -# If this succeeds, we use the downloaded wheel and skip building flash-attention. -# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the -# architectures specified in `FA_GFX_ARCHS` -ARG TRY_FA_WHEEL="1" -ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl" ARG FA_GFX_ARCHS="gfx90a;gfx942" -ARG FA_BRANCH="23a2b1c2" +ARG FA_BRANCH="3cea2fb" # Whether to build triton on rocm ARG BUILD_TRITON="1" -ARG TRITON_BRANCH="e0fc12c" +ARG TRITON_BRANCH="e192dba" ### Base image build stage FROM $BASE_IMAGE AS base @@ -50,16 +44,17 @@ RUN python3 -m pip install --upgrade pip # Remove sccache so it doesn't interfere with ccache # TODO: implement sccache support across components RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" -# Install torch == 2.5.0 on ROCm + +# Install torch == 2.6.0 on ROCm RUN --mount=type=cache,target=/root/.cache/pip \ case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ + *"rocm-6.2"*) \ python3 -m pip uninstall -y torch torchvision \ && python3 -m pip install --pre \ - torch==2.5.0.dev20240726 \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ - torchvision==0.20.0.dev20240726 \ - --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.1 ;; \ + torch==2.6.0.dev20240918 \ + setuptools-scm>=8 \ + torchvision==0.20.0.dev20240918 \ + --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ *) ;; esac ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer @@ -81,25 +76,18 @@ RUN cd /opt/rocm/share/amd_smi \ ### Flash-Attention wheel build stage FROM base AS build_fa ARG BUILD_FA -ARG TRY_FA_WHEEL -ARG FA_WHEEL_URL ARG FA_GFX_ARCHS ARG FA_BRANCH # Build ROCm flash-attention wheel if `BUILD_FA = 1` RUN --mount=type=cache,target=${CCACHE_DIR} \ if [ "$BUILD_FA" = "1" ]; then \ - if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \ - # If a suitable wheel exists, we download it instead of building FA - mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \ - else \ - mkdir -p libs \ - && cd libs \ - && git clone https://github.com/ROCm/flash-attention.git \ - && cd flash-attention \ - && git checkout "${FA_BRANCH}" \ - && git submodule update --init \ - && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ - 
fi; \ + mkdir -p libs \ + && cd libs \ + && git clone https://github.com/ROCm/flash-attention.git \ + && cd flash-attention \ + && git checkout "${FA_BRANCH}" \ + && git submodule update --init \ + && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ # Create an empty directory otherwise as later build stages expect one else mkdir -p /install; \ fi @@ -114,6 +102,7 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ if [ "$BUILD_TRITON" = "1" ]; then \ mkdir -p libs \ && cd libs \ + && python3 -m pip install ninja cmake wheel pybind11 \ && git clone https://github.com/OpenAI/triton.git \ && cd triton \ && git checkout "${TRITON_BRANCH}" \ @@ -143,13 +132,6 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ --mount=type=bind,source=.git,target=.git \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -Ur requirements-rocm.txt \ - && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ - # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM - wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \ - # Prevent interference if torch bundles its own HIP runtime - && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ - *) ;; esac \ && python3 setup.py clean --all \ && python3 setup.py develop diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index d169fe676dc94..4ed0bfe70071d 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,15 +3,17 @@ Installation with ROCm ====================== -vLLM supports AMD GPUs with ROCm 6.1. +vLLM supports AMD GPUs with ROCm 6.2. Requirements ------------ * OS: Linux -* Python: 3.8 -- 3.11 +* Python: 3.9 -- 3.12 * GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -* ROCm 6.1 +* ROCm 6.2 + +Note: PyTorch 2.5+/ROCm6.2 dropped the support for python 3.8. Installation options: @@ -27,7 +29,7 @@ You can build and install vLLM from source. First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image. -`Dockerfile.rocm `_ uses ROCm 6.1 by default, but also supports ROCm 5.7 and 6.0 in older vLLM branches. +`Dockerfile.rocm `_ uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. @@ -39,13 +41,13 @@ It provides flexibility to customize the build of docker image using the followi Their values can be passed in when running ``docker build`` with ``--build-arg`` options. -To build vllm on ROCm 6.1 for MI200 and MI300 series, you can use the default: +To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default: .. code-block:: console $ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm . -To build vllm on ROCm 6.1 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: +To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: .. code-block:: console @@ -79,9 +81,8 @@ Option 2: Build from source - `ROCm `_ - `PyTorch `_ -- `hipBLAS `_ -For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`. 
+For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`. Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch `Getting Started `_ @@ -90,26 +91,45 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton `_ + .. code-block:: console + + $ python3 -m pip install ninja cmake wheel pybind11 + $ pip uninstall -y triton + $ git clone https://github.com/OpenAI/triton.git + $ cd triton + $ git checkout e192dba + $ cd python + $ pip3 install . + $ cd ../.. + +.. note:: + - If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. + + 2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_ + Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention `_ Alternatively, wheels intended for vLLM use can be accessed under the releases. -.. note:: - - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) +For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. +Note to get your gfx architecture, run `rocminfo |grep gfx`. -3. Build vLLM. - -.. code-block:: console + .. code-block:: console - $ cd vllm - $ pip install -U -r requirements-rocm.txt - $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation + $ git clone https://github.com/ROCm/flash-attention.git + $ cd flash-attention + $ git checkout 3cea2fb + $ git submodule update --init + $ GPU_ARCHS="gfx90a" python3 setup.py install + $ cd .. +.. note:: + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -.. tip:: +3. Build vLLM. - For example, vLLM v0.5.3 on ROCM 6.1 can be built with the following steps: + For example, vLLM on ROCM 6.2 can be built with the following steps: .. code-block:: console @@ -117,7 +137,7 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. $ # Install PyTorch $ pip uninstall torch -y - $ pip install --no-cache-dir --pre torch==2.5.0.dev20240726 --index-url https://download.pytorch.org/whl/nightly/rocm6.1 + $ pip install --no-cache-dir --pre torch==2.6.0.dev20240918 --index-url https://download.pytorch.org/whl/nightly/rocm6.2 $ # Build & install AMD SMI $ pip install /opt/rocm/share/amd_smi @@ -127,15 +147,14 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. $ pip install "numpy<2" $ pip install -r requirements-rocm.txt - $ # Apply the patch to ROCM 6.1 (requires root permission) - $ wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib - $ rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* - $ # Build vLLM for MI210/MI250/MI300. $ export PYTORCH_ROCM_ARCH="gfx90a;gfx942" $ python3 setup.py develop + This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation + + .. tip:: - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. 
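
A quick sanity check after following the build steps above, as a sketch: it relies only on the standard `torch.version.hip` attribute (None on CUDA/CPU wheels) and vLLM's `__version__`, so it should confirm whether the ROCm 6.2 wheels are the ones actually loaded.

import torch
import vllm

print("torch:", torch.__version__)        # ROCm nightlies typically carry a +rocm6.2 suffix
print("HIP runtime:", torch.version.hip)  # None means a non-ROCm torch wheel was picked up
print("GPU visible:", torch.cuda.is_available())
print("vllm:", vllm.__version__)
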
From 88577ac92808cfd9468e4b54b757d5fcbe9aa486 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Mon, 23 Sep 2024 21:43:13 -0700 Subject: [PATCH 08/50] Fix tests in test_scheduler.py that fail with BlockManager V2 (#8728) --- tests/core/test_scheduler.py | 349 ++++++++++++++++++++++++++--------- 1 file changed, 260 insertions(+), 89 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 11168d2423b0e..b3bc00280682c 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -3,7 +3,8 @@ from typing import List, Set, Tuple from unittest.mock import MagicMock -import pytest # noqa +import pytest +from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus @@ -16,9 +17,11 @@ schedule_and_update_computed_tokens) -def test_scheduler_add_seq_group(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_add_seq_group(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) + scheduler_config = SchedulerConfig( + 100, 64, 1, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -27,14 +30,18 @@ def test_scheduler_add_seq_group(): # Add seq group to scheduler. num_seq_group = 4 for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), block_size) + _, seq_group = create_dummy_prompt(str(i), + block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) assert scheduler.get_num_unfinished_seq_groups() == i + 1 -def test_scheduler_abort_seq_group(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_abort_seq_group(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) + scheduler_config = SchedulerConfig( + 100, 64, 1, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -54,11 +61,16 @@ def test_scheduler_abort_seq_group(): assert scheduler.get_num_unfinished_seq_groups() == 0 -def test_scheduler_schedule_simple(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_schedule_simple(use_v2_block_manager: bool): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + 64, + num_seq_group, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -67,7 +79,9 @@ def test_scheduler_schedule_simple(): # Add seq groups to scheduler. 
for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -91,20 +105,24 @@ def test_scheduler_schedule_simple(): append_new_token(out, 1) -def test_scheduler_prefill_prioritized(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): """Verify running batched tokens are not applied to prefill requests.""" block_size = 4 max_model_len = 30 max_batched_num_tokens = 30 - scheduler_config = SchedulerConfig(max_batched_num_tokens, 2, - max_model_len) + scheduler_config = SchedulerConfig( + max_batched_num_tokens, + 2, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 2 - cache_config.num_gpu_blocks = 2 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. - _, seq_group_a = create_dummy_prompt("1", 1) + _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size) scheduler.add_seq_group(seq_group_a) # Schedule seq groups prompts. @@ -112,7 +130,7 @@ def test_scheduler_prefill_prioritized(): assert get_sequence_groups(out) == [seq_group_a] # Add a new prefill request B. - _, seq_group_b = create_dummy_prompt("2", 30) + _, seq_group_b = create_dummy_prompt("2", 30, block_size=block_size) scheduler.add_seq_group(seq_group_b) # Verify prefill requests are prioritized. Since max_batched_num_tokens @@ -121,18 +139,24 @@ def test_scheduler_prefill_prioritized(): assert get_sequence_groups(out) == [seq_group_b] -def test_scheduler_schedule_preempt_abort(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len) + scheduler_config = SchedulerConfig( + 64, 2, max_model_len, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. 
- seq_a, seq_group_a = create_dummy_prompt("1", block_size) - seq_b, seq_group_b = create_dummy_prompt("2", block_size) + seq_a, seq_group_a = create_dummy_prompt("1", + block_size, + block_size=block_size) + seq_b, seq_group_b = create_dummy_prompt("2", + block_size, + block_size=block_size) scheduler.add_seq_group(seq_group_a) scheduler.add_seq_group(seq_group_b) @@ -170,12 +194,17 @@ def test_scheduler_schedule_preempt_abort(): assert scheduler.get_num_unfinished_seq_groups() == 1 -def test_scheduler_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_max_seqs(use_v2_block_manager: bool): block_size = 4 num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + 64, + max_seq_group, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -184,7 +213,9 @@ def test_scheduler_max_seqs(): all_seq_groups: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) all_seq_groups.append(seq_group) # Append 1 seq group @@ -211,9 +242,15 @@ def test_scheduler_max_seqs(): assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) -def test_scheduler_delay_factor(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_delay_factor(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5) + scheduler_config = SchedulerConfig( + 100, + 64, + 16, + delay_factor=0.5, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -221,7 +258,8 @@ def test_scheduler_delay_factor(): # schedule first prompt seq_group_meta, seq_group = create_dummy_prompt("0", - prompt_length=block_size) + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert out.num_prefill_groups > 0 @@ -231,7 +269,8 @@ def test_scheduler_delay_factor(): # wait for a second before scheduling next prompt time.sleep(1) seq_group_meta, seq_group = create_dummy_prompt("1", - prompt_length=block_size) + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) # second prompt should *not* be scheduled @@ -248,11 +287,20 @@ def test_scheduler_delay_factor(): append_new_token(out, 1) -def test_swapped_out_prioritized(): - scheduler = initialize_scheduler(max_num_seqs=6) +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_swapped_out_prioritized(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(max_num_seqs=6, + block_size=block_size, + use_v2_block_manager=use_v2_block_manager, + num_cpu_blocks=64, + num_gpu_blocks=64) # best_of=2 * 3 == 6 sequences. for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) # prefill scheduled now. 
@@ -276,7 +324,10 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): append_new_token(out, 1) # Add 1 more task. Swap should be prioritized over prefill. - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) append_new_token(out, 1) @@ -287,17 +338,26 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def initialize_scheduler(*, - max_num_seqs=1000, - max_token_budget=1000, - max_model_len=1000, - lora_config=None): - block_size = 4 - scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs, - max_model_len) +def initialize_scheduler( + *, + max_num_seqs=1000, + max_token_budget=1000, + max_model_len=1000, + lora_config=None, + use_v2_block_manager=False, + block_size=4, + num_cpu_blocks=8, + num_gpu_blocks=8, +): + block_size = block_size + scheduler_config = SchedulerConfig( + max_token_budget, + max_num_seqs, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = num_cpu_blocks + cache_config.num_gpu_blocks = num_gpu_blocks scheduler = Scheduler(scheduler_config, cache_config, lora_config) return scheduler @@ -319,12 +379,18 @@ def add_token_budget(budget: SchedulingBudget, budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) -def test_prefill_schedule_max_prompt_len(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool): """ Test prompt longer than max_prompt_len is aborted. """ - scheduler = initialize_scheduler(max_model_len=30) - _, seq_group = create_dummy_prompt("0", prompt_length=60) + block_size = 4 + scheduler = initialize_scheduler(max_model_len=30, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size) + _, seq_group = create_dummy_prompt("0", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) budget = create_token_budget() output = scheduler._schedule_prefills(budget, None) @@ -336,14 +402,21 @@ def test_prefill_schedule_max_prompt_len(): assert len(remaining_waiting) == 0 -def test_prefill_schedule_token_budget(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_token_budget(use_v2_block_manager: bool): """ Test token budget respected. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(token_budget=0) for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) # 0 token budget == nothing is scheduled. @@ -366,10 +439,15 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 1 # Test when current_batched_tokens respected. 
- scheduler = initialize_scheduler() + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16) budget = create_token_budget(token_budget=60) add_token_budget(budget, 30, 0) - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) # Cannot schedule a prompt that doesn't fit the budget. scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) @@ -389,14 +467,21 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 0 -def test_prefill_schedule_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_seqs(use_v2_block_manager: bool): """ Test max seq respected. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(max_num_seqs=2) for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -410,7 +495,9 @@ def test_prefill_schedule_max_seqs(): scheduler.waiting = deque() budget = create_token_budget(max_num_seqs=2) add_token_budget(budget, 0, 2) - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -421,17 +508,24 @@ def test_prefill_schedule_max_seqs(): assert len(remaining_waiting) == 1 -def test_prefill_schedule_max_lora(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_lora(use_v2_block_manager: bool): """ Test max lora is respected and prioritized. """ + block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config) + scheduler = initialize_scheduler(lora_config=lora_config, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(token_budget=120) curr_loras: Set[int] = set() for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, + block_size=block_size, lora_request=LoRARequest( lora_name=str(i), lora_int_id=i + 1, @@ -443,7 +537,9 @@ def test_prefill_schedule_max_lora(): # If a request is not scheduled because it hits max lora, it is # prioritized. Verify that. for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) # Schedule 2 requests (0 and 2) output = scheduler._schedule_prefills(budget, curr_loras) @@ -467,14 +563,21 @@ def test_prefill_schedule_max_lora(): assert budget.num_batched_tokens == 60 -def test_prefill_schedule_no_block_manager_capacity(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager): """ Test sequence cannot be scheduled due to block manager has no capacity. 
""" - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_gpu_blocks=128, + num_cpu_blocks=128) budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER @@ -489,7 +592,9 @@ def test_prefill_schedule_no_block_manager_capacity(): scheduler = initialize_scheduler() budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER @@ -502,14 +607,21 @@ def test_prefill_schedule_no_block_manager_capacity(): assert len(remaining_waiting) == 0 -def test_decode_schedule_preempted(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_decode_schedule_preempted(use_v2_block_manager: bool): """ Test decodes cannot be scheduled and preempted. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) curr_loras = None for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._add_seq_group_to_running(seq_group) @@ -541,15 +653,23 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -def test_decode_swap_beam_search(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_decode_swap_beam_search(use_v2_block_manager: bool): """ Test best_of > 1 swap out blocks """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_gpu_blocks=64, + num_cpu_blocks=64) curr_loras = None budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) scheduler._add_seq_group_to_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -589,12 +709,20 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -def test_schedule_decode_blocks_to_copy_update(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): """ Verify blocks_to_copy is updated. 
""" - scheduler = initialize_scheduler() - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=4, + num_cpu_blocks=16, + num_gpu_blocks=16) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) curr_loras = None scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -644,12 +772,17 @@ def test_schedule_swapped_simple(): assert blocks_to_swap_out == blocks_to_swap_in_reverse -def test_schedule_swapped_max_token_budget(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -676,12 +809,19 @@ def test_schedule_swapped_max_token_budget(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_max_seqs(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_seqs(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=4) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -706,14 +846,21 @@ def test_schedule_swapped_max_seqs(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_max_loras(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_loras(use_v2_block_manager: bool): + block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config) + scheduler = initialize_scheduler(lora_config=lora_config, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras: Set[int] = set() blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, + block_size=block_size, lora_request=LoRARequest( lora_name=str(i), lora_int_id=i + 1, @@ -734,12 +881,20 @@ def test_schedule_swapped_max_loras(): assert len(curr_loras) == 1 -def test_schedule_swapped_cannot_swap_in(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: 
List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -759,12 +914,20 @@ def test_schedule_swapped_cannot_swap_in(): assert len(output.prefill_seq_groups) == 0 -def test_infeasible_swap(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_infeasible_swap(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -785,10 +948,18 @@ def test_infeasible_swap(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_blocks_to_copy(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_blocks_to_copy(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) blocks_to_swap_out: List[Tuple[int, int]] = [] From 0250dd68c5df12ead29d2ec7d922855c9a257b06 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 23 Sep 2024 22:08:12 -0700 Subject: [PATCH 09/50] re-implement beam search on top of vllm core (#8726) Co-authored-by: Brendan Wong --- benchmarks/benchmark_throughput.py | 24 ++++- tests/conftest.py | 14 +++ tests/samplers/test_beam_search.py | 6 +- vllm/entrypoints/llm.py | 136 ++++++++++++++++++++++++++++- 4 files changed, 171 insertions(+), 9 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e1a5d4ee28ea1..68b401d5bbbb7 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -90,6 +90,7 @@ def run_vllm( download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, + use_new_beam_search_impl: bool = False, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -132,9 +133,23 @@ def run_vllm( max_tokens=output_len, )) - start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() + if not use_new_beam_search_impl: + start = time.perf_counter() + llm.generate(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + else: + assert use_beam_search + prompts = [prompt for prompt, _, _ in requests] + # output_len should be the same for all requests. 
+ output_len = requests[0][2] + for prompt, input_len, _output_len in requests: + assert _output_len == output_len + start = time.perf_counter() + llm.beam_search(prompts, + beam_width=n, + max_tokens=output_len, + ignore_eos=True) + end = time.perf_counter() return end - start @@ -336,7 +351,7 @@ def main(args: argparse.Namespace): run_args.append(args.disable_frontend_multiprocessing) elapsed_time = uvloop.run(run_vllm_async(*run_args)) else: - elapsed_time = run_vllm(*run_args) + elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -396,6 +411,7 @@ def main(args: argparse.Namespace): default=1, help="Number of generated sequences per prompt.") parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--use-new-beam-search-impl", action="store_true") parser.add_argument("--num-prompts", type=int, default=1000, diff --git a/tests/conftest.py b/tests/conftest.py index c2616bcf7091c..69ac4aaee0fda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -798,6 +798,20 @@ def generate_beam_search( outputs = self.generate(prompts, beam_search_params) return outputs + def generate_beam_search_new( + self, + prompts: Union[List[str], List[List[int]]], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + outputs = self.model.beam_search(prompts, beam_width, max_tokens) + returned_outputs = [] + for output in outputs: + token_ids = [x.tokens for x in output.sequences] + texts = [x.text for x in output.sequences] + returned_outputs.append((token_ids, texts)) + return returned_outputs + def encode(self, prompts: List[str]) -> List[List[float]]: req_outputs = self.model.encode(prompts) outputs = [] diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 98a02dec895d2..a9bedc2956fdd 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -9,7 +9,7 @@ # 1. Increase max_tokens to 256. # 2. Increase beam_width to 8. # 3. Use the model "huggyllama/llama-7b". -MAX_TOKENS = [128] +MAX_TOKENS = [64] BEAM_WIDTHS = [4] MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] @@ -33,8 +33,8 @@ def test_beam_search_single_input( max_tokens) with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_beam_search(example_prompts, - beam_width, max_tokens) + vllm_outputs = vllm_model.generate_beam_search_new( + example_prompts, beam_width, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a86c51d23b34d..387813f374daa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,8 @@ +import itertools from contextlib import contextmanager -from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast, - overload) +from dataclasses import dataclass +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, + Union, cast, overload) from tqdm import tqdm @@ -30,6 +32,37 @@ logger = init_logger(__name__) +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. 
+ tokens: List[int] + cum_logprob: float = 0.0 + text: Optional[str] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: List[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: List[int]): + self.beams: List[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens) + ] + self.completed: List[BeamSearchSequence] = [] + + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -354,6 +387,105 @@ def generate( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) + def beam_search( + self, + prompts: List[Union[str, List[int]]], + beam_width: int, + max_tokens: int, + ignore_eos: bool = False, + ) -> List[BeamSearchOutput]: + """ + Generate sequences using beam search. + + Args: + prompts: A list of prompts. Each prompt can be a string or a list + of token IDs. + beam_width: The number of beams to keep at each step. + max_tokens: The max number of tokens to generate for each prompt. + + TODO: how does beam search work together with length penalty, frequency + penalty, and stopping criteria, etc.? + """ + + tokenizer = self.get_tokenizer() + # generate 2 * beam_width candidates at each step + # following the huggingface transformers implementation + # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=0.0) + instances: List[BeamSearchInstance] = [] + + for prompt in prompts: + prompt_tokens = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + instances.append(BeamSearchInstance(prompt_tokens)) + + for _ in range(max_tokens): + all_beams: List[BeamSearchSequence] = list( + sum((instance.beams for instance in instances), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances)) + instance_start_and_end: List[Tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False) + + for (start, end), instance in zip(instance_start_and_end, + instances): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the max-model-len + # or abortion. we don't need to add it to the new beams. 
+ logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=lambda x: x.cum_logprob, + reverse=True) + instance.beams = sorted_beams[:beam_width] + + outputs = [] + for instance in instances: + instance.completed.extend(instance.beams) + sorted_completed = sorted(instance.completed, + key=lambda x: x.cum_logprob, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens) + outputs.append(BeamSearchOutput(sequences=best_beams)) + + return outputs + def chat( self, messages: List[ChatCompletionMessageParam], From 3185fb0ccae73816018d0936c03171b7cf1ba2f8 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 23 Sep 2024 22:45:20 -0700 Subject: [PATCH 10/50] Revert "[Core] Rename `PromptInputs` to `PromptType`, and `inputs` to `prompt`" (#8750) --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/mq_llm_engine/test_error_handling.py | 12 +-- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- vllm/engine/async_llm_engine.py | 24 +++--- vllm/engine/llm_engine.py | 9 +- vllm/engine/multiprocessing/__init__.py | 4 +- vllm/engine/multiprocessing/client.py | 20 +++-- vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 80 ++++++++--------- vllm/inputs/__init__.py | 6 +- vllm/inputs/data.py | 26 +++--- vllm/inputs/parse.py | 22 ++--- vllm/inputs/preprocess.py | 86 +++++++++---------- 18 files changed, 162 insertions(+), 157 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index eadf994cacd34..a39d1cf842f06 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompts: List[PromptType] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index e112b43aade5e..241b2ccd0991e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides 
experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 0d47281db485e..9adf82d43f3e0 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptType +.. autodata:: vllm.inputs.PromptInputs .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index ca5b125369c85..08db891665044 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 7c466c92d5293..49cfc5aa04c36 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -165,7 +165,7 @@ async def bad_abort_after_2s(): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - prompt="Hello my name is", + inputs="Hello my name is", sampling_params=SamplingParams(max_tokens=2000), request_id=uuid.uuid4()): pass @@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. 
with pytest.raises(ValueError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 3ffa126070ca0..e27fd77923412 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - prompt="Hello my name is Robert and", + inputs="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 8f477ea84756d..90363b3e49b73 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version_tuple__", "LLM", "ModelRegistry", - "PromptType", + "PromptInputs", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f108751056ab5..34e7e05341f02 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: async def add_request_async( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -420,7 +420,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType): async def add_request( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -797,7 +797,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - prompt=prompt, + inputs=inputs, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +808,7 @@ async def add_request( async def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, 
lora_request: Optional[LoRARequest] = None, @@ -822,7 +822,8 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -880,7 +881,7 @@ async def generate( """ async for output in await self.add_request( request_id, - prompt, + inputs, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -890,7 +891,7 @@ async def generate( async def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -903,7 +904,8 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -957,7 +959,7 @@ async def encode( """ async for output in await self.add_request( request_id, - prompt, + inputs, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1e77a01bfa9d9..bd7b3250e31af 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptType) + InputRegistry, LLMInputs, PromptInputs) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -689,7 +689,7 @@ def stop_remote_worker_execution_loop(self) -> None: def add_request( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -704,7 +704,8 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. 
@@ -744,7 +745,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 09aa279f1e22c..700332864d17a 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -3,7 +3,7 @@ from typing import List, Mapping, Optional, Union from vllm import PoolingParams -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - prompt: PromptType + inputs: PromptInputs params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 71099115ea125..aa9dbbd448af2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -25,7 +25,7 @@ RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -375,7 +375,7 @@ def dead_error(self) -> BaseException: def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -389,7 +389,8 @@ def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -398,13 +399,13 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(prompt, sampling_params, request_id, + return self._process_request(inputs, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -417,7 +418,8 @@ def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -428,12 +430,12 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. 
""" - return self._process_request(prompt, pooling_params, request_id, + return self._process_request(inputs, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -466,7 +468,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - prompt=prompt, + inputs=inputs, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 3b0f617629d63..485db0bab1297 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -252,7 +252,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - prompt=request.prompt, + inputs=request.inputs, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d0bbeb357b506..70444faa670a2 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptInputs from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request.""" + """Generates outputs for a request""" ... def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 387813f374daa..ca80dedd29ebd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -12,7 +12,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -293,8 +293,8 @@ def generate( @overload def generate( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -311,7 +311,7 @@ def generate( ) def generate( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -329,9 +329,7 @@ def generate( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. 
See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -357,13 +355,12 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -378,7 +375,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -533,9 +530,9 @@ def chat( conversation, mm_data = parse_chat_messages(messages, model_config, tokenizer) - prompt_data: Union[str, List[int]] + prompt: Union[str, List[int]] if isinstance(tokenizer, MistralTokenizer): - prompt_data = apply_mistral_chat_template( + prompt = apply_mistral_chat_template( tokenizer, messages=messages, chat_template=chat_template, @@ -543,7 +540,7 @@ def chat( tools=tools, ) else: - prompt_data = apply_hf_chat_template( + prompt = apply_hf_chat_template( tokenizer, conversation=conversation, chat_template=chat_template, @@ -551,17 +548,17 @@ def chat( tools=tools, ) - prompt: PromptType - if is_list_of(prompt_data, int): - prompt = TokensPrompt(prompt_token_ids=prompt_data) + inputs: PromptInputs + if is_list_of(prompt, int): + inputs = TokensPrompt(prompt_token_ids=prompt) else: - prompt = TextPrompt(prompt=prompt_data) + inputs = TextPrompt(prompt=prompt) if mm_data is not None: - prompt["multi_modal_data"] = mm_data + inputs["multi_modal_data"] = mm_data return self.generate( - prompt, + inputs, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -631,8 +628,8 @@ def encode( @overload def encode( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -649,7 +646,7 @@ def encode( ) def encode( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -665,9 +662,9 @@ def encode( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. 
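On the entrypoint side, the same rename appears in LLM.generate and LLM.encode above; a minimal sketch of batched generation with the TextPrompt and TokensPrompt schemas follows, assuming a small placeholder model. The model name and token ids are illustrative only and not part of this patch.

    # Hypothetical sketch; model name and token ids are placeholders.
    from vllm import LLM, SamplingParams
    from vllm.inputs import TextPrompt, TokensPrompt

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(
        [
            TextPrompt(prompt="The capital of France is"),
            TokensPrompt(prompt_token_ids=[1, 2, 3, 4]),
        ],
        sampling_params=SamplingParams(temperature=0.0, max_tokens=8),
    )
    for out in outputs:
        print(out.outputs[0].text)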
@@ -690,20 +687,19 @@ def encode( ) if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -747,9 +743,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - parsed_prompts: List[PromptType] = [] + inputs: List[PromptInputs] = [] for i in range(num_requests): - item: PromptType + item: PromptInputs if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -758,24 +754,24 @@ def _convert_v1_inputs( else: raise AssertionError - parsed_prompts.append(item) + inputs.append(item) - return parsed_prompts + return inputs def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: - if isinstance(prompts, (str, dict)): + if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] + inputs = [inputs] - num_requests = len(prompts) + num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -792,9 +788,9 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
- for i, prompt in enumerate(prompts): + for i, request_inputs in enumerate(inputs): self._add_request( - prompt, + request_inputs, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, @@ -803,7 +799,7 @@ def _validate_and_add_requests( def _add_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -811,7 +807,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - prompt, + inputs, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index ba1bef1ab3ecc..0b08e9691f915 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptType", - "SingletonPrompt", + "PromptInputs", + "SingletonPromptInputs", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index e072bb65714b9..75ab0c770155b 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptType` may be employed +A prompt of type :class:`SingletonPromptInputs` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,12 +55,12 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) @@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptType` schemas, and are not + :class:`SingletonPromptInputs` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. @@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptType` instances. + :class:`SingletonPromptInputs` instances. 
""" encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -140,8 +140,12 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) -_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) +_T1 = TypeVar("_T1", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) +_T2 = TypeVar("_T2", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) def build_explicit_enc_dec_prompt( diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e5fa1e4184277..ac9d355c64c80 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(prompt, str): - return ParsedStrPrompt(type="str", content=prompt) - elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: + if isinstance(inputs, str): + return ParsedStrPrompt(type="str", content=inputs) + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore - elif "prompt" in prompt: - return ParsedTextPrompt(type="text", content=prompt) + content=inputs) # type: ignore + elif "prompt" in inputs: + return ParsedTextPrompt(type="text", content=inputs) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(prompt, dict) and "encoder_prompt" in prompt + inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(inputs, dict) and "encoder_prompt" in inputs def is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 1f1b048d37e9b..be2aa5f8cb7d0 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, - SingletonPrompt) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -206,7 +206,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -216,7 +216,7 @@ def _extract_prompt_components( Arguments: * request_id - * prompt: single encoder or decoder input prompt + * inputs: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -226,24 +226,24 @@ def _extract_prompt_components( * 
multi_modal_data ''' - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -251,33 +251,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -285,7 +285,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -311,7 +311,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> EncoderDecoderLLMInputs: ''' @@ -339,7 +339,7 @@ def _process_encoder_decoder_prompt( Arguments: - * prompt: an input prompt + * inputs: an input prompt * request_id Returns: @@ -350,13 +350,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_comps = self._extract_prompt_components( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -365,7 +365,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, ) @@ -375,20 +375,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: 
PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_task = self._extract_prompt_components_async( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -401,7 +401,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, ) @@ -425,7 +425,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -436,7 +436,7 @@ def _process_decoder_only_prompt( Arguments: - * prompt: input prompt + * inputs: input prompt * request_id * lora_request * prompt_adapter_request @@ -447,7 +447,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -459,14 +459,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -478,7 +478,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -488,17 +488,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -506,7 +506,7 @@ def preprocess( async def preprocess_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -516,17 +516,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From b8747e8a7c318ab774862f94ccbdbba5b7d9dd4a 
Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 23 Sep 2024 23:10:03 -0700 Subject: [PATCH 11/50] [MISC] Skip dumping inputs when unpicklable (#8744) --- vllm/worker/model_runner_base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 975b88c0e79a2..86883cf152449 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -137,7 +137,15 @@ def _wrapper(*args, **kwargs): for t in kv_caches if is_tensor(t)] - pickle.dump(dumped_inputs, filep) + try: + pickle.dump(dumped_inputs, filep) + except Exception as pickle_err: + logger.warning( + "Failed to pickle inputs of failed execution: %s", + str(pickle_err)) + raise type(err)(f"Error in model execution: " + f"{str(err)}") from err + logger.info( "Completed writing input of failed execution to %s.", filename) From 3f06bae9079ee495a34cfadcd9c1ef2a23636084 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Tue, 24 Sep 2024 00:14:15 -0700 Subject: [PATCH 12/50] [Core][Model] Support loading weights by ID within models (#7931) --- vllm/model_executor/model_loader/loader.py | 60 +++++++++++++++++----- vllm/model_executor/models/ultravox.py | 30 +++++++++-- 2 files changed, 73 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f0d2a9e7f06be..aea3354cada90 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,6 +1,7 @@ # ruff: noqa: SIM117 import collections import copy +import dataclasses import fnmatch import glob import json @@ -8,7 +9,8 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple, Type +from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple, + Type, cast) import gguf import huggingface_hub @@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig, class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str, return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], - fall_back_to_pt: bool + self, source: "Source" ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision, fall_back_to_pt) + source.model_or_path, source.revision, source.fall_back_to_pt) if self.load_config.load_format == LoadFormat.NPCACHE: # Currently np_cache only support *.bin checkpoints assert use_safetensors is False weights_iterator = np_cache_weights_iterator( - model_name_or_path, self.load_config.download_dir, hf_folder, + source.model_or_path, self.load_config.download_dir, hf_folder, hf_weights_files) elif use_safetensors: weights_iterator = 
safetensors_weights_iterator(hf_weights_files) @@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator): xm.mark_step() weights_iterator = _xla_weights_iterator(weights_iterator) - return weights_iterator + + # Apply the prefix. + return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True)) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast(Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ())) + for source in secondary_weights: + yield from self._get_weights_iterator(source) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, @@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig, model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - model.load_weights( - self._get_weights_iterator(model_config.model, - model_config.revision, - fall_back_to_pt=getattr( - model, - "fall_back_to_pt_during_load", - True)), ) + + model.load_weights(self._get_all_weights(model_config, model)) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 32a0e895005cb..71808eb4c2719 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.utils import (flatten_bn, @@ -334,14 +335,23 @@ def __init__(self, self.multi_modal_config = multimodal_config assert self.multi_modal_config + self.secondary_weights = [] + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) if config.audio_model_id is not None: - self.audio_tower = ModifiedWhisperEncoder.from_pretrained( - config.audio_model_id) - else: - self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + self.secondary_weights.append( + DefaultModelLoader.Source( + model_or_path=config.audio_model_id, + revision=None, + prefix="audio_tower.", + )) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + if config.text_model_id is not None: + self.secondary_weights.append( + DefaultModelLoader.Source(model_or_path=config.text_model_id, + revision=None, + prefix="language_model.")) def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: @@ -466,6 +476,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components weights_group = group_weights_with_prefix(weights) + # load audio tower weights + audio_tower_weights = weights_group["audio_tower"] + audio_tower_params_dict = dict( + self.audio_tower.named_parameters( + prefix=self.audio_tower.base_model_prefix)) + for name, 
loaded_weight in audio_tower_weights: + if name in audio_tower_params_dict: + param = audio_tower_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + # load projector weights projector_weights = weights_group["multi_modal_projector"] projector_params_dict = dict( From 8ff7ced996d5dc8b682913471f36c9fefb0e843f Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 24 Sep 2024 01:36:46 -0600 Subject: [PATCH 13/50] [Model] Expose Phi3v num_crops as a mm_processor_kwarg (#8658) Signed-off-by: Alex-Brooks Co-authored-by: Cyrus Leung Co-authored-by: DarkLight1337 --- examples/offline_inference_vision_language.py | 14 ++ ...e_inference_vision_language_multi_image.py | 13 ++ .../vision_language/test_phi3v.py | 186 +++++++++++++++++- vllm/model_executor/models/phi3v.py | 31 ++- 4 files changed, 230 insertions(+), 14 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index c1129316a6e30..6675aa0109a68 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -83,10 +83,24 @@ def run_phi3v(question, modality): # In this example, we override max_num_seqs to 5 while # keeping the original context length of 128k. + + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, max_num_seqs=5, + mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 92ab4f42baa80..8c5f1a7b7af08 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -67,11 +67,24 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. 
+ # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"num_crops": 4}, ) placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index e248151c40a60..eba0a1a1bce42 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,16 +1,21 @@ import os import re -from typing import List, Optional, Tuple, Type +from typing import Callable, List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +import torch +from transformers import AutoImageProcessor, AutoTokenizer +from vllm.inputs import InputContext, LLMInputs +from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID +from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from ...utils import check_logprobs_close +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import build_model_context, check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -71,7 +76,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, mm_limit=2, tensor_parallel_size=1, ) + + +### Fast tests for correctness in processor_kwarg override handling + + +# Wrap lazy imports to avoid initializing CUDA during test collection +@pytest.fixture() +def input_processor_for_phi3v(): + from vllm.model_executor.models.phi3v import input_processor_for_phi3v + return input_processor_for_phi3v + + +@pytest.fixture() +def dummy_data_for_phi3v(): + from vllm.model_executor.models.phi3v import dummy_data_for_phi3v + return dummy_data_for_phi3v + + +@pytest.fixture() +def get_max_phi3v_image_tokens(): + from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens + return get_max_phi3v_image_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops", [4, 16, None]) +def test_input_mapper_override(model: str, image_assets: _ImageAssets, + num_crops: Optional[int]): + """Ensure that the [default] input mapper handles num_crops properly.""" + # We pass the processor kwargs here since for this model, we fall back to + # the default mapper; this will fall back to the HF mapper and forward + # mm_processor_kwargs to it. 
+ mm_processor_kwargs = { + "num_crops": num_crops + } if num_crops is not None else {} + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + ) + + hf_processor = AutoImageProcessor.from_pretrained(model, + trust_remote_code=True, + **mm_processor_kwargs) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ) + + vllm_result = mm_registry.map_input( + ctx.model_config, + {"image": image}, + ) + + assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) + assert torch.all( + hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) + + # For pixel values, the second axis should be the num_crops + 1 + # for the rescaled original image. The default value in VLLM falls + # back to the HF config, which is why we compare to the processor num_crops + assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) + assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_max_tokens", [ + (4, 781), + (16, 2653), +]) +def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, + num_crops: int, expected_max_tokens: int): + """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" + # NOTE: mm_processor_kwargs on the context in this test is unused, since + # this is testing the mapper directly. In practice, the processor kwargs + # are wrapped in a closure when calling the max tokens func. We explicitly + # do NOT use the mm_processor_kwargs in the model context here to ensure + # that the max image tokens implementation is referencing a mix of the + # kwargs to the function and the original mm_processor_kwargs in case + # values are somehow updated and end up in a bad state. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + actual_max_tokens = get_max_phi3v_image_tokens( + InputContext(ctx.model_config), + num_crops=num_crops, + ) + + assert expected_max_tokens == actual_max_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ + (4, 781, 1), + (4, 781, 2), + (16, 2653, 1), + (16, 2653, 2), +]) +def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, + num_crops: int, toks_per_img: int, num_imgs: int): + """Ensure dummy_data_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the dummy data func. 
+ ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + sequence_data, _, = dummy_data_for_phi3v( + ctx=ctx, + seq_len=8192, # Should be bigger than num_imgs * toks_per_img + mm_counts={"image": num_imgs}, + num_crops=num_crops, + ) + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) + assert img_tok_count == toks_per_img * num_imgs + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), +]) +def test_input_processor_override(input_processor_for_phi3v: Callable, + image_assets: _ImageAssets, model: str, + num_crops: int, expected_toks_per_img: int, + num_imgs: int): + """Ensure input_processor_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model) + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + images = [image_assets[0].pil_image] * num_imgs + + llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) + + proc_llm_inputs = input_processor_for_phi3v( + ctx=ctx, + llm_inputs=llm_inputs, + num_crops=num_crops, + ) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6f17f571ccaea..245381518a7f8 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): +def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): transposed = False if width < height: width, height = height, width @@ -337,8 +337,10 @@ def get_phi3v_image_feature_size( *, input_height: int, input_width: int, + num_crops: int, ) -> int: - num_crops = hf_config.get("num_crops", 16) + if num_crops is None: + num_crops = hf_config.get("num_crops", 16) new_width, new_height = _calc_hd_transform_size(width=input_width, height=input_height, hd_num=num_crops) @@ -347,20 +349,26 @@ def get_phi3v_image_feature_size( + (new_height // 336 + 1) * 12 -def get_max_phi3v_image_tokens(ctx: InputContext): +def get_max_phi3v_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): return get_phi3v_image_feature_size( ctx.get_hf_image_processor_config(), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + num_crops=num_crops, ) -def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_data_for_phi3v(ctx: InputContext, + seq_len: int, + 
mm_counts: Mapping[str, int], + *, + num_crops: Optional[int] = None): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx) + image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) seq_data = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, @@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig, return image_placeholder_token_ids -def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_phi3v(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops: Optional[int] = None): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size = [ get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h) + input_height=h, + num_crops=num_crops) ] image_data = [image_data] elif is_list_of(image_data, Image.Image): @@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size.append( get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h)) + input_height=h, + num_crops=num_crops)) elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape elif is_list_of(image_data, torch.Tensor): From cc4325b66ac49e403ed9e1a8c38156a5324e1174 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Tue, 24 Sep 2024 01:08:14 -0700 Subject: [PATCH 14/50] [Bugfix] Fix potentially unsafe custom allreduce synchronization (#8558) --- csrc/custom_all_reduce.cuh | 128 +++++++++++++++++++-------------- csrc/custom_all_reduce_test.cu | 14 ++-- 2 files changed, 83 insertions(+), 59 deletions(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1ed49b8aa9cae..632b579c55afa 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -23,17 +24,23 @@ namespace vllm { -constexpr int kMaxBlocks = 64; -// note: we don't want to use atomics for signals because peer atomics are no -// supported on PCIe links +constexpr int kMaxBlocks = 36; +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; struct Signal { - alignas(128) uint32_t start[kMaxBlocks][8]; - alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. We use + // alternating counter array to avoid this possibility. + alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal* signals[8]; }; +struct __align__(16) RankSignals { Signal* signals[8]; }; // like std::array, but aligned template @@ -123,47 +130,60 @@ DINLINE O downcast(array_t val) { } } -// This function is meant to be used as the first synchronization in the all -// reduce kernel. Thus, it doesn't need to make any visibility guarantees for -// prior memory accesses. Note: volatile writes will not be reordered against -// other volatile writes. 
-template -DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->end[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]); - } - __syncthreads(); +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; +} + +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; } -// This function is meant to be used as the second or the final synchronization -// barrier in the all reduce kernel. If it's the final synchronization barrier, -// we don't need to make any visibility guarantees for prior memory accesses. -template -DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - __syncthreads(); - // eliminate the case that prior writes are not visible after signals become - // visible. Note that I did not managed to make this happen through a lot of - // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. - if constexpr (!final_sync) __threadfence_system(); +// is_start: whether this is the very first synchronization barrier. +// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. +template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, + int rank) { + if constexpr (!is_start) __syncthreads(); + static_assert( + !(is_start && need_fence)); // Start barrier shouldn't need fence. if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->start[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]); + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. 
+ auto peer_counter_ptr = + &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = + &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val); + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val); + } } - if constexpr (!final_sync) __syncthreads(); + if constexpr (is_start || need_fence) __syncthreads(); } template @@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); } template -DINLINE P* get_tmp_buf(volatile Signal* sg) { +DINLINE P* get_tmp_buf(Signal* sg) { return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); } auto tmp_out = tmps[0]; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed @@ -437,6 +455,8 @@ class CustomAllreduce { #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index f7868233076cd..c8b5d0a013f63 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,15 +1,15 @@ /** * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. - * export MPI_HOME=XXX + * export MPI_HOME=xxx * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o - * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi + * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. * * To run: - * mpirun -np 8 ./custom_all_reduce_test + * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test */ #include #include @@ -302,15 +302,19 @@ int main(int argc, char** argv) { bool performance_test = true; cudaProfilerStart(); - // for (int threads : {256, 512}) { + // Uncomment to scan through different block size configs. + // for (int threads : {256, 512, 1024}) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) { - // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, + // performance_test); // } // } + // Scan through different sizes to test performance. 
for (int sz = 512; sz <= (8 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); } cudaProfilerStop(); + MPICHECK(MPI_Finalize()); return EXIT_SUCCESS; } From a928ded99519f803d4cf6389df6acc707239a5cc Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 24 Sep 2024 18:31:42 +0200 Subject: [PATCH 15/50] [Kernel] Split Marlin MoE kernels into multiple files (#8661) Co-authored-by: mgoin --- CMakeLists.txt | 5 + csrc/moe/marlin_kernels/marlin_moe_kernel.h | 1425 ++++++++++++++++ .../marlin_kernels/marlin_moe_kernel_ku4b8.cu | 29 + .../marlin_kernels/marlin_moe_kernel_ku4b8.h | 20 + .../marlin_moe_kernel_ku8b128.cu | 29 + .../marlin_moe_kernel_ku8b128.h | 18 + csrc/moe/marlin_moe_ops.cu | 1453 +---------------- 7 files changed, 1552 insertions(+), 1427 deletions(-) create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel.h create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a05b53cba43f5..b2fa72d4775c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -316,6 +316,11 @@ set(VLLM_MOE_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/marlin_kernels/marlin_moe_kernel.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h new file mode 100644 index 0000000000000..0bd3017226c94 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -0,0 +1,1425 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/scalar_type.hpp" + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. 
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. 
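+  // The immediate (0xf0 & 0xcc) | 0xaa is the 8-bit truth table of
+  // f(a, b, c) = (a & b) | c, evaluated with a = 0xf0, b = 0xcc, c = 0xaa.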
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
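+    // The acq_rel fence orders all prior global writes before the relaxed
+    // atomic add below, which bumps the lock so the next threadblock waiting
+    // in barrier_acquire() can proceed.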
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. 
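+  // (A "slice" here is one column block of the output that a group of
+  // threadblocks cooperates on; slice_idx / slice_count drive the barriers and
+  // the global reduction further below.)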
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + constexpr int sorted_sh_stride = threads; + constexpr int sorted_gl_stride = threads; + + // Global A read index of current thread. 
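+  // (Strides and indices here are in units of int4, i.e. 16 bytes: 8 fp16
+  // values for A/C/scales, or 32 4-bit / 16 8-bit packed weights for B.)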
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int* sh_sorted = (int*)(sh_s + shs_size); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? 
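+  // transform_a below XORs each int4's linear shared-memory index with its
+  // row, a standard swizzle that spreads the accesses of 8 consecutive threads
+  // across distinct banks.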
+ auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 
1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // TODO we are currently hitting illegal memory accesses when fetching + // sorted_ids to shared data: fix this + auto fetch_sorted_ids_to_shared = [&]() { + const int mpt = ceildiv(prob_m, threads); + for (int i = 0; i < mpt; i++) { + if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { + sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = + sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; + } + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
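+  // (k % 2 selects between the two halves of the register double buffer in
+  // frag_a / frag_b_quant, so the loads for iteration k + 1 can overlap the
+  // matmul of iteration k.)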
+ auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = 
sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. 
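+      // Roughly: in each round, warp groups with index in [i, 2*i) fold their
+      // partial sums into shared memory one stride down, so after log2(red_off)
+      // rounds group 0 holds the complete sum for this m block.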
+ + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. 
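+        // Stage the partial results another threadblock already wrote to C
+        // into shared memory, so they can be accumulated in fp32 below.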
+ #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; 
+ i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
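+    // Slice epilogue: fetch per-column scales if needed, reduce partial sums
+    // within the threadblock and (when several blocks share the slice) across
+    // threadblocks, write the result, then advance to the next slice.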
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* 
__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ 
expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks + +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +// const int SHARED_MEM = +// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu new file mode 100644 index 0000000000000..cbafd9ffe7474 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -0,0 +1,29 @@ +#include "marlin_moe_kernel_ku4b8.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. 
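+// Returns true if one of the kU4B8 instantiations below matched the given
+// configuration and was launched, and false otherwise.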
+bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h new file mode 100644 index 0000000000000..9eacb42c115f0 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu new file mode 100644 index 0000000000000..c46712474f715 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -0,0 +1,29 @@ +#include "marlin_moe_kernel_ku8b128.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. 
+bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h new file mode 100644 index 0000000000000..7cd9acafb3b80 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -0,0 +1,18 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks); + +} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 293a6fad72c2f..dfe0437414013 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -26,6 +26,8 @@ #include #include "core/scalar_type.hpp" +#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" +#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" template inline std::string str(T x) { @@ -34,230 +36,8 @@ inline std::string str(T x) { namespace marlin_moe { -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. 
-__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. 
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -335,1106 +115,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ inline void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. 
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - constexpr int sorted_sh_stride = threads; - constexpr int sorted_gl_stride = threads; - - // Global A read index of current thread. 
- int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - int* sh_sorted = (int*)(sh_s + shs_size); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? 
- auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 
1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // TODO we are currently hitting illegal memory accesses when fetching - // sorted_ids to shared data: fix this - auto fetch_sorted_ids_to_shared = [&]() { - const int mpt = ceildiv(prob_m, threads); - for (int i = 0; i < mpt; i++) { - if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { - sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = - sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; - } - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
- auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = 
sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. 
- #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; 
- i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - // __syncthreads(); - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* 
__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} - #else __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -1454,81 +134,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* 
__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks - -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - #endif -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -// const int SHARED_MEM = -// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ - } - typedef struct { int thread_k; int thread_n; @@ -1703,25 +310,27 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION(q_type, 
thread_n_blocks, thread_k_blocks, \ + has_act_order, group_blocks, num_threads, blocks, \ + max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \ + sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \ + expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ + locks, replicate_input, apply_weights, m_block, \ + max_par, exec_cfg.max_m_blocks)) { \ + } + +void marlin_mm_moe(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par, bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1845,26 +454,16 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int tot_m_blocks = ceildiv(tot_m, 16); for (int m_block = 0; m_block < tot_m_blocks; m_block += 4 * exec_cfg.max_m_blocks) { - // make it max possible value - int thread_m_blocks = exec_cfg.max_m_blocks; - if (false) { } - CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + ", has_act_order = " + str(has_act_order) + ", num_groups = " + str(num_groups) + ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + ", thread_n_blocks = " + str(thread_n_blocks) + ", thread_k_blocks = " + str(thread_k_blocks)); } @@ -1943,7 +542,7 @@ torch::Tensor marlin_gemm_moe( } } - marlin_moe::marlin_mm_moe_f16i4( + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), From 2529d09b5a4a124a316b6976e7d782f54e0bddde Mon Sep 17 00:00:00 2001 From: Andy <37781802+aandyw@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:44:11 -0400 Subject: [PATCH 16/50] [Frontend] Batch inference for llm.chat() API (#8648) Co-authored-by: Cyrus Leung Co-authored-by: Cyrus Leung Co-authored-by: Roger Wang Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- examples/offline_inference_chat.py | 27 +++++++++ tests/entrypoints/llm/test_generate.py | 35 +++++++++++ vllm/entrypoints/llm.py | 82 +++++++++++++++----------- 3 files changed, 111 insertions(+), 33 deletions(-) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index c2020724c72fe..8814f4d7bef0d 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -39,6 +39,33 @@ def print_outputs(outputs): use_tqdm=False) print_outputs(outputs) +# You can run batch inference with llm.chat API +conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, 
+ { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] +conversations = [conversation for _ in range(10)] + +# We turn on tqdm progress bar to verify it's indeed running batch inference +outputs = llm.chat(messages=conversations, + sampling_params=sampling_params, + use_tqdm=True) +print_outputs(outputs) + # A chat template can be optionally supplied. # If not, the model will use its default chat template. diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index ef34bebbb0f8c..cd989225e2483 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -162,6 +162,41 @@ def test_chat(): assert len(outputs) == 1 +def test_multi_chat(): + + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + prompt2 = "Explain what among us is." + + conversation1 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + + conversation2 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ] + + messages = [conversation1, conversation2] + + outputs = llm.chat(messages) + assert len(outputs) == 2 + + @pytest.mark.parametrize("image_urls", [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) def test_chat_multi_image(image_urls: List[str]): diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ca80dedd29ebd..cd10eda8c212c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -485,7 +485,8 @@ def beam_search( def chat( self, - messages: List[ChatCompletionMessageParam], + messages: Union[List[ChatCompletionMessageParam], + List[List[ChatCompletionMessageParam]]], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, use_tqdm: bool = True, @@ -505,8 +506,9 @@ def chat( to the OpenAI API. Args: - messages: A single conversation represented as a list of messages. - Each message is a dictionary with 'role' and 'content' keys. + messages: A list of conversations or a single conversation. + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -523,42 +525,56 @@ def chat( A list of ``RequestOutput`` objects containing the generated responses in the same order as the input messages. 
""" + list_of_messages: List[List[ChatCompletionMessageParam]] - tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() - - conversation, mm_data = parse_chat_messages(messages, model_config, - tokenizer) - - prompt: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): - prompt = apply_mistral_chat_template( - tokenizer, - messages=messages, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is List[List[...]] + list_of_messages = messages else: - prompt = apply_hf_chat_template( - tokenizer, - conversation=conversation, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # messages is List[...] + list_of_messages = [messages] + + prompts: List[Union[TokensPrompt, TextPrompt]] = [] + + for msgs in list_of_messages: + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + + conversation, mm_data = parse_chat_messages( + msgs, model_config, tokenizer) + + prompt_data: Union[str, List[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt_data = apply_mistral_chat_template( + tokenizer, + messages=msgs, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + else: + prompt_data = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + + prompt: Union[TokensPrompt, TextPrompt] + if is_list_of(prompt_data, int): + prompt = TokensPrompt(prompt_token_ids=prompt_data) + else: + prompt = TextPrompt(prompt=prompt_data) - inputs: PromptInputs - if is_list_of(prompt, int): - inputs = TokensPrompt(prompt_token_ids=prompt) - else: - inputs = TextPrompt(prompt=prompt) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data - if mm_data is not None: - inputs["multi_modal_data"] = mm_data + prompts.append(prompt) return self.generate( - inputs, + prompts, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, From 72fc97a0f100b92f1ff6c6a16e27d12f1c7569aa Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 24 Sep 2024 14:33:21 -0400 Subject: [PATCH 17/50] [Bugfix] Fix torch dynamo fixes caused by `replace_parameters` (#8748) --- .../layers/quantization/utils/layer_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py index c38bd8955f457..edce6d19b6c49 100644 --- a/vllm/model_executor/layers/quantization/utils/layer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]) -> None: old = getattr(mod, name) - if old.dtype == new.dtype and \ + if type(old) is type(new) and old.dtype == new.dtype and \ old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): # If we can just update in-place to avoid re-registering # can be faster if the underlying storage is the same update_tensor_inplace(old, new) else: - # Fallback re-register parameter + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for 
`torch.compile` compatibility if not isinstance(new, torch.nn.Parameter): - new = torch.nn.Parameter(new) - mod.register_parameter(name, torch.nn.Parameter(new)) + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, + torch.nn.Parameter(new, requires_grad=False)) From 2467b642dd9bde32a334fe5967efd78a53aa49da Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:38:12 +0200 Subject: [PATCH 18/50] [CI/Build] fix setuptools-scm usage (#8771) --- .gitignore | 7 ++----- pyproject.toml | 3 --- setup.py | 7 +++++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 43eb89cacc0a5..abeaf0a82e303 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ -# vllm commit id, generated by setup.py -vllm/commit_id.py +# version file generated by setuptools-scm +/vllm/_version.py # vllm-flash-attn built from source vllm/vllm_flash_attn/ @@ -196,8 +196,5 @@ _build/ *_hip* hip_compat.h -# version file generated by setuptools-scm -/vllm/_version.py - # Benchmark dataset benchmarks/*.json diff --git a/pyproject.toml b/pyproject.toml index 4e1841484420a..c9057b061aad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,9 +51,6 @@ ignore = [ "UP032", ] -[tool.setuptools_scm] -version_file = "vllm/_version.py" - [tool.mypy] python_version = "3.8" diff --git a/setup.py b/setup.py index 85a2852136eaa..8ef759f5245fc 100644 --- a/setup.py +++ b/setup.py @@ -354,12 +354,15 @@ def get_path(*filepath) -> str: def get_vllm_version() -> str: - version = get_version() + version = get_version( + write_to="vllm/_version.py", # TODO: move this to pyproject.toml + ) + sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): if envs.VLLM_TARGET_DEVICE == "empty": - version += "+empty" + version += f"{sep}empty" elif _is_cuda(): cuda_version = str(get_nvcc_cuda_version()) if cuda_version != MAIN_CUDA_VERSION: From 1e7d5c01f5c35424eede1bbe6f723dd8781120f0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 24 Sep 2024 15:48:39 -0700 Subject: [PATCH 19/50] [misc] soft drop beam search (#8763) --- vllm/envs.py | 5 +++++ vllm/sampling_params.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 43c7aa8af85b2..705d858e71a66 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -62,6 +62,7 @@ VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False + VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False def get_default_cache_root(): @@ -195,6 +196,10 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # If set, allowing the use of deprecated beam search implementation + "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH": + lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1", + # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 86e80ae5e224d..f9ba4b4777e4d 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -8,6 +8,7 @@ import torch from typing_extensions import Annotated +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -260,6 +261,10 @@ def __post_init__(self) -> None: self._verify_args() if self.use_beam_search: + if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: + raise 
ValueError( + "Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa + ) self._verify_beam_search() else: self._verify_non_beam_search() From 13f9f7a3d0373421ee9fd7498e450214e134aa6c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Sep 2024 08:08:55 +0800 Subject: [PATCH 20/50] [[Misc]Upgrade bitsandbytes to the latest version 0.44.0 (#8768) --- docs/source/quantization/bnb.rst | 2 +- examples/lora_with_quantization_inference.py | 26 +++++++--------- requirements-test.txt | 2 +- tests/quantization/test_bitsandbytes.py | 2 +- vllm/config.py | 30 ++++++++++++++----- .../layers/quantization/bitsandbytes.py | 8 ++--- vllm/model_executor/model_loader/loader.py | 8 ++--- 7 files changed, 44 insertions(+), 34 deletions(-) diff --git a/docs/source/quantization/bnb.rst b/docs/source/quantization/bnb.rst index aefb54a8acb65..682938cc63d48 100644 --- a/docs/source/quantization/bnb.rst +++ b/docs/source/quantization/bnb.rst @@ -11,7 +11,7 @@ Below are the steps to utilize BitsAndBytes with vLLM. .. code-block:: console - $ pip install bitsandbytes>=0.42.0 + $ pip install bitsandbytes>=0.44.0 vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. diff --git a/examples/lora_with_quantization_inference.py b/examples/lora_with_quantization_inference.py index 3b2347c1115e1..0c454ea50f665 100644 --- a/examples/lora_with_quantization_inference.py +++ b/examples/lora_with_quantization_inference.py @@ -79,23 +79,17 @@ def initialize_engine(model: str, quantization: str, # It quantizes the model when loading, with some config info from the # LoRA adapter repo. So need to set the parameter of load_format and # qlora_adapter_name_or_path as below. 
- engine_args = EngineArgs( - model=model, - quantization=quantization, - qlora_adapter_name_or_path=lora_repo, - load_format="bitsandbytes", - enable_lora=True, - max_lora_rank=64, - # set it only in GPUs of limited memory - enforce_eager=True) + engine_args = EngineArgs(model=model, + quantization=quantization, + qlora_adapter_name_or_path=lora_repo, + load_format="bitsandbytes", + enable_lora=True, + max_lora_rank=64) else: - engine_args = EngineArgs( - model=model, - quantization=quantization, - enable_lora=True, - max_loras=4, - # set it only in GPUs of limited memory - enforce_eager=True) + engine_args = EngineArgs(model=model, + quantization=quantization, + enable_lora=True, + max_loras=4) return LLMEngine.from_engine_args(engine_args) diff --git a/requirements-test.txt b/requirements-test.txt index 10d463de27be5..9c6fadb88865a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -30,5 +30,5 @@ datamodel_code_generator # required for minicpm3 test aiohttp # quantization -bitsandbytes==0.42.0 +bitsandbytes>=0.44.0 buildkite-test-collector==0.1.8 diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 36167cf95f589..ac2ebc622ba6f 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -107,7 +107,7 @@ def validate_generated_texts(hf_runner, quantization='bitsandbytes', load_format='bitsandbytes', tensor_parallel_size=vllm_tp_size, - enforce_eager=True, + enforce_eager=False, gpu_memory_utilization=0.8) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") diff --git a/vllm/config.py b/vllm/config.py index 8c65d99c44651..562564bbfa032 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -222,6 +222,7 @@ def __init__(self, self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() + self._verify_bnb_config() def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] @@ -337,6 +338,28 @@ def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.44.0) with 8-bit models does not + yet support CUDA graph. 
+ """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = (getattr(self.hf_config, + "quantization_config", None) + is not None) + is_8bit = (self.hf_config.quantization_config.get( + "load_in_8bit", False) if has_quantization_config else False) + if all([ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ]): + logger.warning( + "CUDA graph is not supported on BitAndBytes 8bit yet, " + "fallback to the eager mode.") + self.enforce_eager = True + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -401,13 +424,6 @@ def verify_with_parallel_config( "Pipeline parallelism is only supported for the following " f" architectures: {_PP_SUPPORTED_MODELS}.") - # Remove the constraint after the bitsandbytes issue is fixed: - # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 - if self.quantization == "bitsandbytes" and self.enforce_eager is False: - logger.warning("CUDA graph is not supported on BitAndBytes yet, " - "fallback to the eager mode.") - self.enforce_eager = True - if pipeline_parallel_size > 1 and self.use_async_output_proc: logger.warning("Async output processor is not supported with " "pipeline parallelism currently. Disabling it.") diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 66bc5395dbd7a..38495d5a5a863 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -121,12 +121,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase): def __init__(self, quant_config: BitsAndBytesConfig): try: import bitsandbytes - if bitsandbytes.__version__ < "0.42.0": + if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.42.0.") + "install bitsandbytes>=0.44.0.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.42.0 via " - "`pip install bitsandbytes>=0.42.0` to use " + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " "bitsandbytes quantizer.") from err self.quant_config = quant_config diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index aea3354cada90..c21b10d661ecc 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -851,12 +851,12 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes - if bitsandbytes.__version__ < "0.42.0": + if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. 
Please " - "install bitsandbytes>=0.42.0.") + "install bitsandbytes>=0.44.0.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.42.0 via " - "`pip install bitsandbytes>=0.42.0` to use " + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " "bitsandbytes quantizer.") from err hf_weights_files, use_safetensors = self._prepare_weights( From 01b6f9e1f0530a7cb81486ff34d3d935e4f75d28 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 24 Sep 2024 18:29:56 -0600 Subject: [PATCH 21/50] [Core][Bugfix] Support prompt_logprobs returned with speculative decoding (#8047) Signed-off-by: Travis Johnson --- tests/conftest.py | 4 +- tests/spec_decode/e2e/conftest.py | 139 ++++++++++++------ .../spec_decode/e2e/test_eagle_correctness.py | 58 ++++++++ tests/spec_decode/e2e/test_logprobs.py | 95 ++++++------ .../e2e/test_medusa_correctness.py | 59 ++++++++ tests/spec_decode/e2e/test_mlp_correctness.py | 57 ++++++- .../spec_decode/e2e/test_ngram_correctness.py | 59 ++++++++ vllm/engine/output_processor/multi_step.py | 9 +- vllm/model_executor/layers/sampler.py | 11 +- vllm/sequence.py | 2 + vllm/spec_decode/batch_expansion.py | 10 +- vllm/spec_decode/spec_decode_worker.py | 62 ++++++-- vllm/spec_decode/util.py | 45 +++++- vllm/transformers_utils/detokenizer.py | 16 +- 14 files changed, 492 insertions(+), 134 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 69ac4aaee0fda..dcd9afdae3c14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -675,8 +675,6 @@ def generate_w_logprobs( videos: Optional[PromptVideoInput] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: - assert sampling_params.logprobs is not None - if images is not None: assert len(prompts) == len(images) @@ -754,7 +752,7 @@ def generate_greedy_logprobs( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, - prompt_logprobs=(num_prompt_logprobs), + prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids) return self.generate_w_logprobs(prompts, diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 3d93f4a23b68a..b450ef97c89d4 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,13 +1,16 @@ from itertools import cycle -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Union import pytest from vllm import LLM, SamplingParams from vllm.model_executor.utils import set_random_seed +from vllm.sequence import PromptLogprobs, SampleLogprobs from ...conftest import cleanup -from ...models.utils import check_logprobs_close, check_outputs_equal +from ...models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + check_logprobs_close, check_outputs_equal) from ...utils import RemoteOpenAIServer PROMPTS = [ @@ -81,45 +84,77 @@ def get_output_from_llm_generator( return tokens, token_ids, acceptance_rate -def run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - max_output_len: int, - seed: Optional[int] = 0, - temperature: float = 0.0, - logprobs: int = 1): - org_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **baseline_llm_kwargs, - } - - sd_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - - sampling_params = 
SamplingParams(temperature=temperature, - max_tokens=max_output_len, - seed=seed, - logprobs=logprobs) - - with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - with vllm_runner(**sd_args) as vllm_model: - sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - check_logprobs_close(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, - name_0="org", - name_1="sd") +def check_logprobs_correctness( + spec_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + baseline_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + disable_logprobs: bool = False, +): + """Compare sampled and prompt logprobs between baseline and spec decoding + """ + if not disable_logprobs: + return check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=spec_outputs, + name_0="org", + name_1="sd", + ) + + # Check correctness when disable_logprobs == True + for spec_output, baseline_output in zip(spec_outputs, baseline_outputs): + # Check generated token logprobs. + spec_logprobs = spec_output[2] + baseline_logprobs = baseline_output[2] + _check_logprobs_when_output_disabled(spec_logprobs, + baseline_logprobs, + is_prompt_logprobs=False) + + # Check prompt logprobs too, if they exist + if len(baseline_output) == 4: + assert len(spec_output) == 4 + spec_prompt_logprobs = spec_output[3] + baseline_prompt_logprobs = baseline_output[3] + _check_logprobs_when_output_disabled(spec_prompt_logprobs, + baseline_prompt_logprobs, + is_prompt_logprobs=True) + + +def _check_logprobs_when_output_disabled( + spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + is_prompt_logprobs: bool = False, +): + # Prompt logprobs are optional + if is_prompt_logprobs and baseline_logprobs is None: + assert spec_logprobs is None + return + + assert spec_logprobs is not None + assert baseline_logprobs is not None + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. 
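Editorial aside (not part of the patch; the per-position loop introduced by the comment above continues immediately below). When logprobs are disabled during spec decoding, each generated position is expected to carry exactly one entry with dummy score and rank, so only the token id is meaningful. A minimal sketch of that shape, assuming the Logprob fields from vllm.sequence as used elsewhere in this series:

from vllm.sequence import Logprob

# One dummy Logprob per generated position, keyed by the sampled token id;
# logprob and rank are placeholders, only the key is checked against baseline.
spec_logprobs_example = [
    {1234: Logprob(logprob=0.0, rank=-1)},  # position 0
    {5678: Logprob(logprob=0.0, rank=-1)},  # position 1
]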
+ for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # First prompt logprob is expected to be None + if is_prompt_logprobs and baseline_pos_logprobs is None: + assert spec_pos_logprobs is None + assert pos == 0 + continue + + assert spec_pos_logprobs is not None + assert baseline_pos_logprobs is not None + + # When disabled, the 1 logprob is returned with dummy values for the + # score and rank, but the token id should match the baseline model + assert len(spec_pos_logprobs) == 1 + (spec_pos_logprob_token_id, + spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) + assert spec_pos_logprob.rank == -1 + assert spec_pos_logprob.logprob == 0.0 + assert spec_pos_logprob_token_id in baseline_pos_logprobs def run_equality_correctness_test( @@ -135,7 +170,10 @@ def run_equality_correctness_test( disable_seed: bool = False, ignore_eos: bool = True, ensure_all_accepted: bool = False, - expected_acceptance_rate: Optional[float] = None): + expected_acceptance_rate: Optional[float] = None, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + disable_logprobs: bool = False): org_args = { **common_llm_kwargs, @@ -157,10 +195,12 @@ def run_equality_correctness_test( sampling_params = SamplingParams(temperature=temperature, max_tokens=max_output_len, seed=seed, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs) with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate(prompts, sampling_params) + org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) with vllm_runner(**sd_args) as vllm_model: if ensure_all_accepted or expected_acceptance_rate is not None: @@ -169,7 +209,7 @@ def run_equality_correctness_test( 'prometheus'] stat_logger.local_interval = -100 - sd_outputs = vllm_model.generate(prompts, sampling_params) + sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) if ensure_all_accepted or expected_acceptance_rate is not None: acceptance_rate = (stat_logger.metrics. @@ -185,11 +225,16 @@ def run_equality_correctness_test( if expected_acceptance_rate is not None: assert acceptance_rate >= expected_acceptance_rate - 1e-2 - check_outputs_equal(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, + # Only pass token entries, not the logprobs + check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs], + outputs_1_lst=[out[0:2] for out in sd_outputs], name_0="org", name_1="sd") + # Check logprobs if requested + if logprobs is not None or prompt_logprobs is not None: + check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs) + def run_equality_correctness_test_tp(model, common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index f2af2c2bedb12..d7ca8815ec259 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 03c1733f104ff..b7d54991e0535 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -4,7 +4,7 @@ from vllm import SamplingParams -from .conftest import run_logprob_correctness_test +from .conftest import run_equality_correctness_test @pytest.mark.parametrize( @@ -25,6 +25,10 @@ "speculative_model": "JackFram/llama-160m", "num_speculative_tokens": 3, "disable_logprobs_during_spec_decoding": False, + }, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": True, }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify output logprobs are equal with and without speculative decoding. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, output_len: int, seed: int, logprobs: int): """Veriy logprob greedy equality with different speculation lens. 
""" - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify logprobs greedy equality when some sequences skip speculation. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, """Check the behavior when logprobs are disabled. Token choices should match with the base model. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7cefe99d026c6..8c90e147df23a 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int, logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 2d0d6fb923ad1..7f3180befaffc 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -16,7 +16,7 @@ * Test greedy equality under various number of speculative tokens. With those tests, we can say at least, MLPSpeculator would not break the -correctess for the target model outputs. +correctness for the target model outputs. """ from unittest.mock import patch @@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [8]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 89301f24e1159..850114eb7f5a8 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model_name": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b5..31c2bbc8e7127 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -9,8 +9,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, - SequenceOutput, SequenceStatus) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -110,10 +110,11 @@ def process_outputs(self, # we can take the first sample. samples = [output.samples[0] for output in outputs] - # -1 means the output token is not valid (eg. due to spec decode + # entries in sample tokens may be invalid (eg. due to spec decode # rejecting tokens). valid_samples = [ - sample for sample in samples if sample.output_token != -1 + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID ] assert valid_samples diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2ca86a4653cf4..583bb02dcb5b4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -15,7 +15,8 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -759,10 +760,10 @@ def _sample_with_torch( # Create output tensor for sampled token ids. 
if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.empty(logprobs.shape[0], - 1, - dtype=torch.long, - device=logprobs.device) + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) else: sampled_token_ids_tensor = None diff --git a/vllm/sequence.py b/vllm/sequence.py index 79e8a1f6244d7..b32e1aebe17be 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -26,6 +26,8 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" +VLLM_INVALID_TOKEN_ID = -1 + # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index b2204e8b27afd..9eb8bbfc54076 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,9 +6,9 @@ from vllm import SamplingParams from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SequenceData, SequenceGroupMetadata, - get_all_seq_ids) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len @@ -69,10 +69,10 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() - # Filter the list to ignore -1 proposals. + # Filter the list to ignore invalid proposals. proposal_token_ids_list_without_skips = [ proposals for proposals in proposal_token_ids_list - if -1 not in proposals + if VLLM_INVALID_TOKEN_ID not in proposals ] (spec_indices, non_spec_indices, target_seq_group_metadata_list, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9e645a49f699c..dbf880a8f475c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -13,9 +13,10 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) -from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, - get_all_seq_ids, get_all_seq_ids_and_request_ids) + get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -28,7 +29,8 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (Timer, create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_logprobs_output, + create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -436,8 +438,8 @@ def _serialize_sampler_output_no_logprobs( self, execute_model_req: ExecuteModelRequest, sampler_output: SamplerOutput) -> SamplerOutput: """ - Creates and returns a `SamplerOutput` with only the sampled token IDs - being 
serialized to CPU & populated in `CompletionSequenceGroupOutput`. + Creates and returns a `SamplerOutput` with only the token IDs being + serialized to CPU and populated in `CompletionSequenceGroupOutput`. All other parameters in `CompletionSequenceGroupOutput` related to log probabilities are skipped. @@ -449,14 +451,46 @@ def _serialize_sampler_output_no_logprobs( Returns: SamplerOutput: A new `SamplerOutput` instance containing a list of - `CompletionSequenceGroupOutput` objects with only sampled token - IDs populated. + `CompletionSequenceGroupOutput` objects with only token IDs + populated. """ - seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) - sampled_token_ids_list = sampler_output.sampled_token_ids.tolist() + seq_output_prompt_logprobs = [ + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 + for seq in execute_model_req.seq_group_metadata_list + ] + # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID + sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + # subtracting is faster than testing for equality + sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ + if any(seq_output_prompt_logprobs) else \ + sampler_output.sampled_token_ids).tolist() + + seq_data_entries = ( + (seq_id, seq_data) for sg in \ + execute_model_req.seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ) completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] - for index, seq_id in enumerate(seq_ids): + for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ + enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)): + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) + # no prompt logprobs for the first token + for p_token_id in prompt_token_ids[1:] + ] + else: + prompt_logprobs = None + completion_seq_group_output_list.append( create_sequence_group_output( token_id=sampled_token_ids_list[index][0], @@ -465,7 +499,7 @@ def _serialize_sampler_output_no_logprobs( seq_id=seq_id, topk_token_ids=[], topk_logprobs=[], - )) + prompt_logprobs=prompt_logprobs)) return SamplerOutput(outputs=completion_seq_group_output_list) @nvtx_range("spec_decode_worker._run_no_spec") @@ -485,6 +519,12 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # Store hidden states from target model execution. 
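Editorial aside (not part of the patch; the hidden-state handling introduced by the comment above continues immediately below). This hunk and the serialization hunk above both rely on the same filtering idiom: subtract VLLM_INVALID_TOKEN_ID and apply torch.where to the result, which keeps only the rows whose sampled token id is valid. A standalone sketch of the idiom with made-up values:

import torch

VLLM_INVALID_TOKEN_ID = -1
sampled = torch.tensor([[-1], [42], [-1], [7]])  # -1 marks slots never filled by the sampler
valid_rows = torch.where(sampled - VLLM_INVALID_TOKEN_ID)[0]  # tensor([1, 3])
print(sampled[valid_rows])  # tensor([[42], [7]])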
hidden_states = sampler_output.hidden_states if hidden_states is not None: + # remove hidden_states for prompt tokens + if any(seq.is_prompt + for seq in execute_model_req.seq_group_metadata_list): + hidden_states = hidden_states[ + torch.where(sampler_output.sampled_token_ids - + VLLM_INVALID_TOKEN_ID)[0]] if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( hidden_states, execute_model_req.seq_group_metadata_list) diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 54e718bc49017..193ef870dfceb 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceGroupMetadata, SequenceOutput) + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) SeqId = int @@ -49,21 +50,19 @@ def get_sampled_token_logprobs( return sampled_token_ids_ranks, selected_logprobs -def create_sequence_group_output( +def create_logprobs_output( token_id: int, token_id_logprob_rank: int, token_id_logprob: float, - seq_id: SeqId, topk_token_ids: List[Optional[int]], topk_logprobs: List[Optional[float]], -) -> CompletionSequenceGroupOutput: - """Create a SequenceGroupOutput given the sampling results. +) -> Dict[int, Logprob]: + """Create a Logprob Dict for a token given the sampling results. Args: token_id (int): The sampled token for the sequence. token_id_logprob_rank (int): The logprob rank of the sampled token. token_id_logprob (float): The logprob value of the sampled token. - seq_id (int): The sequence id. topk_token_ids (List[Optional[int]]): The list of top-k token ids. topk_logprobs (List[Optional[float]]): The list of top-k logprobs. """ @@ -85,14 +84,44 @@ def create_sequence_group_output( if topk_token_id is not None }) + return logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], + prompt_logprobs: Optional[PromptLogprobs] = None, +) -> CompletionSequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + + logprobs = create_logprobs_output( + token_id, + token_id_logprob_rank, + token_id_logprob, + topk_token_ids, + topk_logprobs, + ) + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, logprobs=logprobs) ], - # TODO add prompt logprobs support. - prompt_logprobs=None, + prompt_logprobs=prompt_logprobs, ) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index d27d7ba9e67bb..2b418f3603a0b 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,13 +1,11 @@ from typing import Dict, List, Optional, Tuple -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, + Sequence, SequenceGroup) from .tokenizer import AnyTokenizer from .tokenizer_group import BaseTokenizerGroup -# Used eg. 
for marking rejected tokens in spec decoding. -INVALID_TOKEN_ID = -1 - class Detokenizer: """Provides methods to decode the output of a model into text.""" @@ -61,7 +59,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, continue for token_id, sample_logprob in prompt_logprobs_for_token.items(): if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): prompt_token_ids_with_token = ( prompt_token_ids[:token_position] + [token_id]) (new_tokens, new_text, new_prefix_offset, @@ -143,7 +141,7 @@ def decode_sequence_inplace(self, seq: Sequence, continue if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): all_input_ids_with_logprob = previous_tokens + [token_id] (_, new_text, _, _) = detokenize_incrementally( tokenizer=tokenizer, @@ -282,14 +280,14 @@ def detokenize_incrementally( assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: + if 0 <= new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) if isinstance(new_tokens, str): new_tokens = [new_tokens] + else: + new_tokens = [""] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. From 6da1ab6b4134d76391a0c31a048e5d04b6283769 Mon Sep 17 00:00:00 2001 From: Archit Patke Date: Tue, 24 Sep 2024 21:50:50 -0500 Subject: [PATCH 22/50] [Core] Adding Priority Scheduling (#5958) --- benchmarks/benchmark_prioritization.py | 295 +++++++++++++++++++++++++ vllm/config.py | 6 +- vllm/core/scheduler.py | 77 +++++++ vllm/engine/llm_engine.py | 24 +- vllm/entrypoints/llm.py | 12 +- vllm/sequence.py | 4 + 6 files changed, 410 insertions(+), 8 deletions(-) create mode 100644 benchmarks/benchmark_prioritization.py diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py new file mode 100644 index 0000000000000..0ba29fabca59b --- /dev/null +++ b/benchmarks/benchmark_prioritization.py @@ -0,0 +1,295 @@ +"""Benchmark offline prioritization.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. 
+ prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + #Select a equi-probable random priority + priority = 0 if random.random() < 0.5 else 1 + + filtered_dataset.append((prompt, prompt_len, output_len, priority)) + + return filtered_dataset + + +def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + gpu_memory_utilization: float = 0.9, + download_dir: Optional[str] = None, +) -> float: + from vllm import LLM, SamplingParams + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + disable_log_stats=False, + ) + + # Add the requests to the engine. + prompts = [] + sampling_params = [] + priority = [] + for prompt, _, output_len, _priority in requests: + prompts.append(prompt) + priority.append(_priority) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + start = time.perf_counter() + llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. 
+ prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len, priority in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=200, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' 
+ 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + main(args) diff --git a/vllm/config.py b/vllm/config.py index 562564bbfa032..308f29a3dc371 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -961,7 +961,7 @@ class SchedulerConfig: workers instead of an entire data. It should be enabled only when SPMD worker architecture is enabled. I.e., VLLM_USE_RAY_SPMD_WORKER=1 - + policy: The scheduling policy to use. "fcfs" (default) or "priority". """ def __init__(self, @@ -977,7 +977,8 @@ def __init__(self, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, multi_step_stream_outputs: bool = False, - send_delta_data: bool = False) -> None: + send_delta_data: bool = False, + policy: str = "fcfs") -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: # It is the values that have the best balance between ITL @@ -1019,6 +1020,7 @@ def __init__(self, self.num_scheduler_steps = num_scheduler_steps self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data + self.policy = policy self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3fa95f57b737..b707d87c3af83 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -766,6 +766,79 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: else: return prompt_limit + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """ Get the priority of the sequence group. 
+ Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. + """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + False, budget) + + #Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + #Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens and can_allocate == AllocStatus.OK + and budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + #Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens = self._get_num_new_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget) + budget.subtract_num_batched_tokens(vseq_group.request_id, + num_running_tokens) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + #Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out, + PreemptionMode.RECOMPUTE) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + #Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + def _schedule_prefills( self, budget: SchedulingBudget, @@ -917,6 +990,10 @@ def _schedule_default(self) -> SchedulerOutputs: curr_loras, enable_chunking=False) + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bd7b3250e31af..c341b236003a3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -631,6 +631,7 @@ def _add_processed_request( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, ) -> None: self._validate_model_inputs(processed_inputs) # Create the sequences. 
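Editorial aside (not part of the patch; the remaining llm_engine.py hunks continue immediately below). A hedged sketch of how the new per-request priority flows through the public API, mirroring the llm.generate(..., priority=...) call in benchmark_prioritization.py above. It assumes the engine was constructed with SchedulerConfig(policy="priority"); how that policy is exposed to users is not shown in this series, so the setup is only an assumption.

from vllm import LLM, SamplingParams

# Assumption: this engine is wired up with the "priority" scheduling policy;
# with the default "fcfs" policy, any priority > 0 raises ValueError
# (see the add_request() check added in the hunk below).
llm = LLM(model="facebook/opt-125m")

prompts = ["Hello, my name is", "The capital of France is"]
params = SamplingParams(temperature=0.0, max_tokens=16)

# One priority per prompt; the scheduler added in this patch sorts the waiting
# and running queues by (priority, arrival_time), so lower values are served
# first and can preempt running requests with higher values.
outputs = llm.generate(prompts, params, priority=[0, 1])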
@@ -661,7 +662,8 @@ def _add_processed_request( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -670,7 +672,8 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -695,6 +698,7 @@ def add_request( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: """Add a request to the engine's request pool. @@ -713,6 +717,8 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. Details: - Set arrival_time to the current time if it is None. @@ -741,6 +747,11 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + + if priority > 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: arrival_time = time.time() @@ -760,6 +771,7 @@ def add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, + priority=priority, ) def _create_sequence_group_with_sampling( @@ -772,6 +784,7 @@ def _create_sequence_group_with_sampling( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -798,7 +811,8 @@ def _create_sequence_group_with_sampling( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group @@ -811,6 +825,7 @@ def _create_sequence_group_with_pooling( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -823,7 +838,8 @@ def _create_sequence_group_with_pooling( lora_request=lora_request, pooling_params=pooling_params, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cd10eda8c212c..77ae7b088398a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -320,7 +320,8 @@ def generate( lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, - GuidedDecodingRequest]] = None + 
GuidedDecodingRequest]] = None, + priority: Optional[List[int]] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -339,6 +340,8 @@ def generate( lora_request: LoRA request to use for generation, if any. prompt_adapter_request: Prompt Adapter request to use for generation, if any. + priority: The priority of the requests, if any. + Only applicable when priority scheduling policy is enabled. Returns: A list of ``RequestOutput`` objects containing the @@ -379,7 +382,8 @@ def generate( params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - guided_options=guided_options_request) + guided_options=guided_options_request, + priority=priority) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -782,6 +786,7 @@ def _validate_and_add_requests( lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[List[int]] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -811,6 +816,7 @@ def _validate_and_add_requests( lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, ) def _add_request( @@ -819,6 +825,7 @@ def _add_request( params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( @@ -827,6 +834,7 @@ def _add_request( params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority, ) def _add_guided_processor( diff --git a/vllm/sequence.py b/vllm/sequence.py index b32e1aebe17be..fda7ef87749a1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -646,6 +646,7 @@ class SequenceGroup: unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request. + priority: User-defined priority of the request. 
""" def __init__( @@ -660,9 +661,11 @@ def __init__( encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: self.request_id = request_id self.seqs = seqs + self.arrival_time = arrival_time self.is_single_seq = len(seqs) == 1 self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -680,6 +683,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers + self.priority = priority self.cached_request_output = None From 6e0c9d6bd07464b311eb098e2dac8196eed16721 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 24 Sep 2024 21:37:38 -0600 Subject: [PATCH 23/50] [Bugfix] Use heartbeats instead of health checks (#8583) --- tests/mq_llm_engine/test_error_handling.py | 15 ++--- vllm/engine/multiprocessing/__init__.py | 7 +- vllm/engine/multiprocessing/client.py | 51 +++++++------- vllm/engine/multiprocessing/engine.py | 77 +++++++++++++++++----- 4 files changed, 87 insertions(+), 63 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 49cfc5aa04c36..76b2f494d5b25 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket): await client.check_health() # Trigger an abort on the client side. - async def bad_abort_after_2s(): - await asyncio.sleep(2.0) - await client.abort(request_id="foo") + # This request ID does not exist, and will cause the engine to error + await client.abort(request_id="foo") - # Trigger an abort in 2s from now. - abort_task = asyncio.create_task(bad_abort_after_2s()) - - # Exception in abort() will happen during this generation. - # This will kill the engine and should return ENGINE_DEAD_ERROR + # Future generation requests will now fail # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( inputs="Hello my name is", - sampling_params=SamplingParams(max_tokens=2000), + sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass assert "KeyError" in repr(execinfo.value) assert client.errored - await abort_task - # This should raise the original error. 
with pytest.raises(RAISED_ERROR): await client.check_health() diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 700332864d17a..165e6cc2146c3 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -43,10 +43,6 @@ class RPCAbortRequest: request_id: str -class RPCHealthRequest: - pass - - class RPCStartupRequest(Enum): IS_SERVER_READY = 1 @@ -56,8 +52,7 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, - RPCStartupRequest] +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index aa9dbbd448af2..7e397cf408fba 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -20,9 +20,8 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptInputs @@ -95,9 +94,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - # IPC path for ack of check_health requests. - self.health_socket: Socket = self.context.socket(zmq.constants.PULL) - self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -124,34 +123,28 @@ def get_data_socket(self) -> Iterator[Socket]: finally: socket.close(linger=0) - async def run_check_health_loop(self, timeout: int): - """Background loop that continually probes the RPCServer for health. - - The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which - the MQLLMEngine server is blocking on. - - The Server replies on the HEALTH_SOCKET (rather than on the - OUTPUT_SOCKET such that the messages are not intermingled with - output streaming). + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually listens to the RPCServer for + heartbeats. """ - try: while True: - if await self.health_socket.poll(timeout=timeout) == 0: - # Wakeup every N seconds and do a health probe. - await self._send_one_way_rpc_request( - RPCHealthRequest(), self.input_socket) - - # Wait for ack from the health socket. - await self._await_ack(error_message="Health check failed.", - socket=self.health_socket) + if await self.heartbeat_socket.poll(timeout=timeout) == 0: + # No heartbeat was received. Set error and exit the loop + self._set_errored( + TimeoutError("No heartbeat received " + "from MQLLMEngine")) + logger.debug("Shutting down MQLLMEngineClient check " + "health loop due to timeout") + break + else: - # Server sent a health status message unprompted. 
+ # Heartbeat received- check the message await self._check_success( - error_message="Health check failed.", - socket=self.health_socket) + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) - logger.debug("Health probe successful.") + logger.debug("Heartbeat successful.") except asyncio.CancelledError: logger.debug("Shutting down MQLLMEngineClient check health loop.") @@ -234,7 +227,7 @@ async def setup(self): # Start health_loop. self.health_loop = asyncio.create_task( - self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) def close(self): """Destroy the ZeroMQ Context.""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 485db0bab1297..b1dd9915cbbf5 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,5 +1,7 @@ import pickle import signal +import threading +import time from contextlib import contextmanager from typing import Iterator, List, Optional, Union @@ -15,10 +17,10 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse) # yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -91,9 +93,9 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - # Send health status back to client. - self.health_socket = self.ctx.socket(zmq.constants.PUSH) - self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -101,6 +103,20 @@ def __init__(self, # Error state. self._errored_with: Optional[BaseException] = None + # Heartbeat thread + self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, + daemon=True) + self._heartbeat_stop_event = threading.Event() + # The heartbeat needs to be faster than what the client will wait for + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0 + + self._last_alive_time = time.time() + # The heartbeats can tolerate a long period of the engine chugging + # away at a generation request. + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0 + @property def dead_error(self) -> BaseException: if self._errored_with is not None: @@ -131,6 +147,8 @@ def start(self): try: logger.debug("Starting Startup Loop.") self.run_startup_loop() + logger.debug("Starting heartbeat thread") + self.heartbeat_thread.start() logger.debug("Starting Engine Loop.") self.run_engine_loop() except Exception as e: @@ -144,6 +162,7 @@ def start(self): def cleanup(self): """Cleanup zeromq state on shutdown.""" # Closes all sockets and destroys context. + self._heartbeat_stop_event.set() self.ctx.destroy(linger=0) del self.engine @@ -182,9 +201,11 @@ def run_engine_loop(self): """Core busy loop of the LLMEngine.""" while True: + self._alive() if not self.engine.has_unfinished_requests(): # Poll until there is work to do. 
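The heartbeat machinery being wired up in this patch is a plain ZeroMQ PUSH/PULL pattern: the engine process pushes a small payload on a dedicated socket at a fixed interval, and the client treats a poll timeout on that socket as engine death. A minimal standalone sketch of the pattern, using pyzmq only and hypothetical names rather than vLLM's actual classes:

import pickle
import threading
import zmq

IPC_PATH = "ipc:///tmp/heartbeat_demo"
BEAT_INTERVAL_S = 0.5      # sender beats well inside the receiver's timeout
RECV_TIMEOUT_MS = 5000     # receiver declares the peer dead after this long

def heartbeat_sender(stop: threading.Event) -> None:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUSH)
    sock.bind(IPC_PATH)
    while not stop.wait(timeout=BEAT_INTERVAL_S):
        sock.send(pickle.dumps("HEALTHY"))
    ctx.destroy(linger=0)

def heartbeat_receiver(n_beats: int) -> None:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PULL)
    sock.connect(IPC_PATH)
    for _ in range(n_beats):
        if sock.poll(timeout=RECV_TIMEOUT_MS) == 0:
            # No heartbeat within the timeout: assume the sender has died.
            print("peer presumed dead (heartbeat timeout)")
            break
        print("heartbeat:", pickle.loads(sock.recv()))
    ctx.destroy(linger=0)

if __name__ == "__main__":
    stop = threading.Event()
    threading.Thread(target=heartbeat_sender, args=(stop,), daemon=True).start()
    heartbeat_receiver(n_beats=3)
    stop.set()

The same split is visible in the diff: the engine sends on a background thread so a long generation step cannot starve the heartbeat, while the client only ever polls and never has to interrupt the engine's input socket.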
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self._alive() self.engine.do_log_stats() logger.debug("Waiting for new requests in engine loop.") @@ -200,7 +221,6 @@ def run_engine_loop(self): def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" - try: return self.engine.step() except SystemExit: @@ -229,10 +249,9 @@ def handle_new_input(self): self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) - elif isinstance(request, RPCHealthRequest): - self._handle_health_request() else: - raise ValueError("Unknown RPCRequest Type: {request}") + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") except Exception as e: self._set_errored(e) @@ -279,13 +298,32 @@ def _handle_abort_request(self, request: RPCAbortRequest): if self.log_requests: logger.info("Aborted request %s.", request.request_id) - def _handle_health_request(self): + def _heartbeat_loop(self): + while not self._heartbeat_stop_event.wait( + timeout=self.heartbeat_interval_seconds): + # Loops until the stop event is set + self._heartbeat() + + logger.debug("Exiting MQLLMEngine heartbeat thread") + + def _heartbeat(self): + # Send unhealthy if engine has already errored if self._errored_with is not None: self._send_unhealthy(self._errored_with) - # Raises error if unhealthy. - self.engine.check_health() - self._send_healthy() + # Check for life of the main loop + elif time.time() - self._last_alive_time > self.last_alive_threshold: + self._send_unhealthy(RuntimeError("Engine loop has died")) + + else: + # Otherwise- check health of the engine + # self.engine.check_health() raises on unhealthy + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" @@ -295,12 +333,14 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_healthy(self): """Send HEALTHY message to RPCClient.""" - self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" - error_bytes = pickle.dumps(error) - self.health_socket.send_multipart((error_bytes, ), copy=False) + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): @@ -313,6 +353,9 @@ def _set_errored(self, e: BaseException): if self._errored_with is None: self._errored_with = e + def _alive(self): + self._last_alive_time = time.time() + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): From ee777d9c30418ffa9d98f98dd27c0ddea346c49c Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:26:18 -0700 Subject: [PATCH 24/50] Fix test_schedule_swapped_simple in test_scheduler.py (#8780) --- tests/core/test_scheduler.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index b3bc00280682c..88c6c3bb28e43 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -747,13 +747,19 @@ def 
test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): assert output.blocks_to_copy == [(2, 3)] -def test_schedule_swapped_simple(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_simple(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=4, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) + append_new_token_seq_group(4, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._add_seq_group_to_swapped(seq_group) From b4522474a32b6e0bf5573a9b6a6830cb787dfb63 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 25 Sep 2024 04:26:33 +0000 Subject: [PATCH 25/50] [Bugfix][Kernel] Implement acquire/release polyfill for Pascal (#8776) --- csrc/custom_all_reduce.cuh | 11 +++++++++++ csrc/custom_all_reduce_test.cu | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 632b579c55afa..a2f7e43300002 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -131,15 +131,26 @@ DINLINE O downcast(array_t val) { } static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#else + asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#endif } static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { FlagType flag; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); +#else + asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" + : "=r"(flag) + : "l"(flag_addr)); +#endif return flag; } diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index c8b5d0a013f63..376687e91cfda 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -44,7 +44,14 @@ } while (0) __global__ void dummy_kernel() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms +#else + for (int i = 0; i < 100; i++) { + long long int start = clock64(); + while (clock64() - start < 150000000); // approximately 98.4ms on P40 + } +#endif } template From fc3afc20df410dd523f94967b98836084f561ab7 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:26:36 -0700 Subject: [PATCH 26/50] Fix tests in test_chunked_prefill_scheduler which fail with BlockManager V2 (#8752) --- tests/core/test_chunked_prefill_scheduler.py | 225 ++++++++++++------- 1 file changed, 143 insertions(+), 82 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 2f6ea632a5d9b..9dddd751c7858 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -27,16 +27,19 @@ def schedule_and_update_computed_tokens(scheduler): return metas, out -def test_simple(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_simple(use_v2_block_manager: 
bool): """Verify basic scheduling works.""" block_size = 4 num_seq_group = 4 max_model_len = 16 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -45,7 +48,9 @@ def test_simple(): # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -69,30 +74,36 @@ def test_simple(): assert len(seq_group_meta) == num_seq_group -def test_chunk(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunk(use_v2_block_manager: bool): """Verify prefills are chunked properly.""" block_size = 4 max_seqs = 60 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) # Verify the second request is chunked. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + print() assert set(get_sequence_groups(out)) == set(running) assert seq_group_meta[0].token_chunk_size == 60 # Verify it is chunked. @@ -113,24 +124,29 @@ def test_chunk(): assert out.num_batched_tokens == 57 -def test_complex(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_complex(use_v2_block_manager: bool): block_size = 4 max_seqs = 60 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 64 + cache_config.num_gpu_blocks = 64 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -151,7 +167,9 @@ def test_complex(): # Add 2 more requests. 
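As a quick sanity check on the numbers asserted in test_chunk above: with max_num_batched_tokens = 64 and two 60-token prompts, the first prefill consumes 60 tokens of the budget, the second prompt is chunked to the remaining 4, and on the following step the 56 leftover prefill tokens plus one decode token for the first request give the 57 batched tokens the test expects. In code form (plain arithmetic, nothing vLLM-specific):

budget = 64
prompt_len = 60

first_prefill = prompt_len                     # 60: fits in the budget whole
second_chunk = budget - first_prefill          # 4: the chunked piece of request 2
remaining_prefill = prompt_len - second_chunk  # 56: still owed to request 2
decode_tokens = 1                              # request 1 has moved on to decoding
print(second_chunk, remaining_prefill + decode_tokens)  # -> 4 57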
for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -176,16 +194,19 @@ def test_complex(): assert running[2].is_prefill() -def test_maximal_decoding(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_maximal_decoding(use_v2_block_manager: bool): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 max_model_len = 8 max_num_batched_tokens = 2 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -194,7 +215,9 @@ def test_maximal_decoding(): # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=2, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -211,7 +234,9 @@ def test_maximal_decoding(): append_new_token(running[0], 1) # Create one more seq_group. - _, seq_group = create_dummy_prompt("3", prompt_length=2) + _, seq_group = create_dummy_prompt("3", + prompt_length=2, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -263,23 +288,28 @@ def test_maximal_decoding(): assert out.num_batched_tokens == 2 -def test_prompt_limit(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prompt_limit(use_v2_block_manager: bool): """Verify max_num_batched_tokens < max_model_len is possible.""" block_size = 4 max_seqs = 32 max_model_len = 64 max_num_batched_tokens = 32 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=48) + _, seq_group = create_dummy_prompt("1", + prompt_length=48, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -293,7 +323,8 @@ def test_prompt_limit(): assert out.num_batched_tokens == 32 -def test_prompt_limit_exceed(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prompt_limit_exceed(use_v2_block_manager: bool): block_size = 4 max_seqs = 64 max_model_len = 32 @@ -303,12 +334,13 @@ def test_prompt_limit_exceed(): max_model_len, enable_chunked_prefill=True) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - - _, seq_group = 
create_dummy_prompt("2", prompt_length=48) + _, seq_group = create_dummy_prompt("2", + prompt_length=48, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -317,22 +349,28 @@ def test_prompt_limit_exceed(): assert out.ignored_seq_groups[0] == seq_group -def test_swap(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_swap(use_v2_block_manager: bool): """Verify swapping works with chunked prefill requests""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -369,21 +407,27 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def test_running_prefill_prioritized_over_swap(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool): block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. 
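The enlarged num_cpu_blocks / num_gpu_blocks values in these parametrized tests follow from simple KV-cache block accounting: each sequence needs roughly ceil(prompt_len / block_size) blocks, ignoring the few extra blocks consumed by appended decode tokens and best_of copies. A rough check against the configurations used in this file (including the 128-block configuration further below):

from math import ceil

def blocks_needed(prompt_len: int, block_size: int, num_seqs: int = 1) -> int:
    # Rough lower bound; decode tokens and best_of copies need a little more.
    return num_seqs * ceil(prompt_len / block_size)

print(blocks_needed(60, 4, num_seqs=2))  # 30 -> covered by the 32-block configs
print(blocks_needed(60, 4, num_seqs=4))  # 60 -> covered by the 64-block configs
print(blocks_needed(65, 4, num_seqs=5))  # 85 -> covered by the 128-block config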
@@ -413,7 +457,9 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER - _, seq_group2 = create_dummy_prompt("2", prompt_length=60) + _, seq_group2 = create_dummy_prompt("2", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group2) _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 @@ -455,22 +501,27 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def test_chunked_prefill_preempt(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunked_prefill_preempt(use_v2_block_manager: bool): """Verify preempt works with chunked prefill requests""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -517,22 +568,27 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens -def test_chunked_prefill_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): block_size = 4 max_seqs = 2 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 128 + cache_config.num_gpu_blocks = 128 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=65) + _, seq_group = create_dummy_prompt("1", + prompt_length=65, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) # The first prefill is chunked. @@ -542,7 +598,9 @@ def test_chunked_prefill_max_seqs(): # Add new requests. 
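The swap and preemption tests in this file lean on a small mocking trick: the block manager's capacity check is replaced with a MagicMock whose side_effect reports "no room" for exactly one sequence group, which forces the scheduler down the swap or preempt path deterministically. A stripped-down sketch of that pattern (the objects here are placeholders, not the real scheduler types):

from unittest.mock import MagicMock

block_manager = MagicMock()
# Say "cannot append slots" only for one particular request, mirroring the
# cannot_append_second_group helpers used in the tests above.
block_manager.can_append_slots.side_effect = (
    lambda seq_group, num_lookahead_slots: seq_group != "victim")

print(block_manager.can_append_slots("other", 0))   # True
print(block_manager.can_append_slots("victim", 0))  # False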
for i in range(4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=65) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=65, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -564,16 +622,19 @@ def test_chunked_prefill_max_seqs(): assert not running[1].is_prefill() -def test_perfix_caching(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_perfix_caching(use_v2_block_manager: bool): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, From e3dd0692fa2c803cd6f59a88d2fdf8bca26d8d96 Mon Sep 17 00:00:00 2001 From: zifeitong Date: Tue, 24 Sep 2024 22:53:43 -0700 Subject: [PATCH 27/50] [BugFix] Propagate 'trust_remote_code' setting in internvl and minicpmv (#8250) --- vllm/model_executor/models/internvl.py | 15 +-- vllm/model_executor/models/minicpmv.py | 137 +++++++++++++++++++------ vllm/model_executor/models/qwen.py | 15 +-- 3 files changed, 126 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 005a24f10aa17..fffd0d4161e10 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -230,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): else: raise TypeError(f"Invalid image type: {type(image_data)}") - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] @@ -278,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): use_thumbnail=use_thumbnail) for img in data ] model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_token_id = tokenizer.encode(IMG_CONTEXT, add_special_tokens=False, return_tensors="pt")[0] @@ -298,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int, model_config = ctx.model_config hf_config = ctx.get_hf_config() vision_config = hf_config.vision_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) seq_data = dummy_seq_data_for_clip( vision_config, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index c0fb6fef78bab..7da7991b4f849 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -33,6 +33,7 @@ from torch import nn from torch.nn.init import trunc_normal_ from transformers import PretrainedConfig +from typing_extensions import NotRequired from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -52,6 +53,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from 
vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData @@ -64,6 +66,17 @@ } +class MiniCPMVImageInput(TypedDict): + """Input mapper input with auxiliary data for computing image bounds.""" + image: Image.Image + + # Image bounds token ids in 0-dim scaler tensor. + im_start_id: torch.Tensor + im_end_id: torch.Tensor + slice_start_id: NotRequired[torch.Tensor] + slice_end_id: NotRequired[torch.Tensor] + + class MiniCPMVImagePixelInputs(TypedDict): pixel_values: List[torch.Tensor] """ @@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict): """ -MiniCPMVImageInputs = MiniCPMVImagePixelInputs - DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -234,6 +245,25 @@ def forward(self, x: torch.Tensor, return x +def _build_image_input(ctx: InputContext, + image: Image.Image) -> MiniCPMVImageInput: + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + trust_remote_code=ctx.model_config.trust_remote_code) + if hasattr(tokenizer, "slice_start_id"): + return MiniCPMVImageInput( + image=image, + im_start_id=torch.tensor(tokenizer.im_start_id), + im_end_id=torch.tensor(tokenizer.im_end_id), + slice_start_id=torch.tensor(tokenizer.slice_start_id), + slice_end_id=torch.tensor(tokenizer.slice_end_id)) + else: + return MiniCPMVImageInput(image=image, + im_start_id=torch.tensor( + tokenizer.im_start_id), + im_end_id=torch.tensor(tokenizer.im_end_id)) + + def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: version_float = getattr(config, "version", None) @@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): return SequenceData.from_token_counts((0, seq_len)) -def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int): +def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, + num_images: int): width = height = hf_config.image_size - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} + image = _build_image_input(ctx, + image=Image.new("RGB", (width, height), + color=0)) + return {"image": [image] if num_images == 1 else [image] * num_images} def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, @@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, num_images = mm_counts["image"] seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) - mm_data = dummy_image_for_minicpmv(hf_config, num_images) + mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) return seq_data, mm_data @@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs model_config = ctx.model_config version = get_version_by_config(model_config.hf_config) - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_processor = cached_get_image_processor(model_config.tokenizer) def get_placeholder(image_size: Tuple[int, int], num_image: int): @@ -317,6 +351,10 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): new_prompt = "".join(new_prompt_chunks) new_token_ids = tokenizer.encode(new_prompt) + multi_modal_data["image"] = [ + 
_build_image_input(ctx, image) for image in images + ] + llm_inputs = LLMInputs( prompt_token_ids=new_token_ids, prompt=new_prompt, @@ -325,6 +363,32 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): return llm_inputs +def input_mapper_for_minicpmv(ctx: InputContext, data: object): + model_config = ctx.model_config + + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + + if not isinstance(data, list): + raise ValueError( + "Image input must be list of MiniCPMVImageInput, got (%s)", data) + batch_data = image_processor \ + .preprocess([img["image"] for img in data], return_tensors="pt") \ + .data + + if len(data) > 0: + batch_data["im_start_id"] = data[0]["im_start_id"] + batch_data["im_end_id"] = data[0]["im_end_id"] + if "slice_start_id" in data[0]: + batch_data["slice_start_id"] = data[0]["slice_start_id"] + batch_data["slice_end_id"] = data[0]["slice_end_id"] + + return MultiModalInputs(batch_data) + + class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): """ The abstract class of MiniCPMV can only be inherited, but cannot be @@ -365,7 +429,7 @@ def __init__( def get_embedding( self, input_ids: torch.Tensor, - image_inputs: Optional[MiniCPMVImageInputs], + image_inputs: Optional[MiniCPMVImagePixelInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) if hasattr(self.config, "scale_emb"): @@ -393,14 +457,20 @@ def get_embedding( return vlm_embedding, vision_hidden_states - def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor: - tokenizer = cached_get_tokenizer(self.config._name_or_path, - trust_remote_code=True) - start_cond = input_ids == tokenizer.im_start_id - end_cond = input_ids == tokenizer.im_end_id - if hasattr(tokenizer, "slice_start_id"): - start_cond |= (input_ids == tokenizer.slice_start_id) - end_cond |= (input_ids == tokenizer.slice_end_id) + def _get_image_bounds( + self, + input_ids: torch.Tensor, + im_start_id: torch.Tensor, + im_end_id: torch.Tensor, + slice_start_id: Optional[torch.Tensor] = None, + slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor: + # All the images in the batch should share the same special image + # bound token ids. 
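The _get_image_bounds helper being reworked here pairs up the positions of the image start/end marker tokens to recover the placeholder spans inside input_ids. A tiny self-contained illustration of that pairing, with made-up token ids (101 and 102 stand in for the real im_start_id / im_end_id values):

import torch

input_ids = torch.tensor([5, 101, 7, 7, 7, 102, 9, 101, 7, 7, 102, 4])
im_start_id, im_end_id = 101, 102

starts, = torch.where(input_ids == im_start_id)
ends, = torch.where(input_ids == im_end_id)
# Shift past the start marker so each row is the [start, end) span of the
# image placeholder tokens themselves.
bounds = torch.hstack([(starts + 1).unsqueeze(-1), ends.unsqueeze(-1)])
print(bounds)  # tensor([[ 2,  5], [ 8, 10]])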
+ start_cond = input_ids == im_start_id[0] + end_cond = input_ids == im_end_id[0] + if slice_start_id is not None: + start_cond |= (input_ids == slice_start_id[0]) + end_cond |= (input_ids == slice_end_id[0]) image_start_tokens, = torch.where(start_cond) image_start_tokens += 1 @@ -419,7 +489,7 @@ def _parse_and_validate_inputs( self, input_ids: torch.Tensor, **kwargs: object, - ) -> Optional[MiniCPMVImageInputs]: + ) -> Optional[MiniCPMVImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", []) tgt_sizes = kwargs.pop("tgt_sizes", []) @@ -456,8 +526,17 @@ def _parse_and_validate_inputs( if len(pixel_values_flat) == 0: return None - return MiniCPMVImageInputs( - image_bounds=self._get_image_bounds(input_ids), + im_start_id = kwargs.pop("im_start_id", None) + im_end_id = kwargs.pop("im_end_id", None) + slice_start_id = kwargs.pop("slice_start_id", None) + slice_end_id = kwargs.pop("slice_end_id", None) + if im_start_id is None: + return None + + return MiniCPMVImagePixelInputs( + image_bounds=self._get_image_bounds(input_ids, im_start_id, + im_end_id, slice_start_id, + slice_end_id), pixel_values=pixel_values_flat, tgt_sizes=torch.stack(tgt_sizes_flat), ) @@ -564,8 +643,8 @@ def get_vision_embedding( ) -> torch.Tensor: raise NotImplementedError - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: raise NotImplementedError def is_default_weight_loading(self, name: str) -> bool: @@ -654,8 +733,8 @@ def get_vision_embedding( res.append(self.resampler(vision_embedding, tgt_size)) return torch.vstack(res) - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] return self.get_vision_embedding(pixel_values) @@ -713,8 +792,8 @@ def get_vision_embedding( vision_embedding = self.resampler(vision_embedding, tgt_sizes) return vision_embedding - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -807,8 +886,8 @@ def get_vision_embedding( ).last_hidden_state return vision_embedding - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -851,7 +930,7 @@ def is_default_weight_loading(self, name: str) -> bool: } -@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e62a841485f2d..761c1370b9776 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -674,8 +674,9 @@ def input_processor_for_qwen(ctx: InputContext, prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + 
model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_data = multi_modal_data["image"] if isinstance(image_data, torch.Tensor): num_dims = len(image_data.shape) @@ -735,8 +736,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: return MultiModalInputs() model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_pair_tok = tokenizer.encode(IMG_START + IMG_END, add_special_tokens=False, @@ -824,8 +826,9 @@ def dummy_data_for_qwen( # We have a visual component - use images to warm up num_images = mm_counts["image"] model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) # Build the image prompts with no imgpads; the tokenizer will add img pads image_prompt = ''.join( From c23953675f78bc85045d66fa98aea7d0581c2167 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 25 Sep 2024 14:16:11 +0800 Subject: [PATCH 28/50] [Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend (#8770) --- vllm/model_executor/models/qwen2_vl.py | 16 +++++ vllm/worker/cpu_model_runner.py | 92 +++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9f72210c60bf9..889ebc6c2e1ff 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -67,6 +67,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import get_processor +from vllm.utils import is_cpu from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -281,6 +282,21 @@ def forward( context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) + elif is_cpu(): + seq_length = q.size(1) + q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], + cu_seqlens[i - 1]:cu_seqlens[i]] = True + output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) + context_layer = rearrange(output, "b h s d -> b s h d ") else: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d7d7d65659b73..cebb0f36a2b28 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -12,11 +12,13 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU: query_lens=seq_lens, ) + def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, + computed_len: int): + mm_kwargs = self.multi_modal_input_mapper(mm_data) + + # special processing for mrope position deltas. + mrope_positions = None + if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. 
+ spatial_merge_size, + context_len=computed_len, + ) + seq_data.mrope_position_delta = mrope_position_delta + return mm_kwargs, mrope_positions + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -153,6 +187,8 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] + slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] @@ -171,14 +207,20 @@ def _prepare_prompt( seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids + mrope_positions = None + if (mm_data := seq_group_metadata.multi_modal_data): + mm_kwargs, mrope_positions = self._compute_multi_modal_input( + seq_data, mm_data, computed_len) + multi_modal_inputs_list.append(mm_kwargs) + # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) - - if (mm_data := seq_group_metadata.multi_modal_data): - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) + if mrope_positions: + for idx in range(3): + input_mrope_positions[idx].extend(mrope_positions[idx]) + else: + input_positions.extend(list(range(computed_len, seq_len))) # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] @@ -202,12 +244,18 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore - input_positions = torch.tensor(input_positions, + input_positions = torch.tensor(input_positions + or input_mrope_positions, dtype=torch.long, device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, @@ -238,6 +286,7 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] @@ -255,7 +304,17 @@ def _prepare_decode( seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append(position) + if seq_data.mrope_position_delta is not None: + context_len = seq_data.get_num_computed_tokens() + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + input_mrope_positions[idx].extend(next_pos[idx]) + else: + input_positions.append(position) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -273,12 +332,18 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) - input_positions = torch.tensor(input_positions, + input_positions = torch.tensor(input_positions + or input_mrope_positions, dtype=torch.long, device=self.device) slot_mapping = torch.tensor(slot_mapping, @@ -373,6 
+438,15 @@ def __init__( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. + mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + def load_model(self) -> None: self.model = get_model(model_config=self.model_config, load_config=self.load_config, From 3e073e66f1790f7ce339dad71514983e6e402f30 Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Wed, 25 Sep 2024 02:16:30 -0400 Subject: [PATCH 29/50] [Bugfix] load fc bias from config for eagle (#8790) --- vllm/model_executor/models/eagle.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index ad1ab0231d861..13811d33768a6 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: self.model = model_cls(self.config.model, *args, **kwargs) self.fc = nn.Linear(config.model.hidden_size * 2, config.model.hidden_size, - bias=False) + bias=getattr(self.config, "bias", False)) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size @@ -136,10 +136,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.truncated_vocab_size < self.config.vocab_size: self.token_map = nn.Parameter(loaded_weight, requires_grad=False) - elif name.startswith("fc."): + elif name.startswith("fc.weight"): weight_loader = getattr(self.fc.weight, "weight_loader", default_weight_loader) weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("fc.bias"): + if self.fc.bias is not None: + weight_loader = getattr(self.fc.bias, "weight_loader", + default_weight_loader) + weight_loader(self.fc.bias, loaded_weight) + else: + raise ValueError("Found bias in the loaded weights " + "but the model config doesn't have bias") elif name.startswith("model.lm_head.") or name.startswith( "model.model."): model_weights[name.split("model.", 1)[-1]] = loaded_weight From 1ac3de09cd87290f7494ce6337623d6edd3f8667 Mon Sep 17 00:00:00 2001 From: Adam Tilghman Date: Wed, 25 Sep 2024 00:49:26 -0700 Subject: [PATCH 30/50] [Frontend] OpenAI server: propagate usage accounting to FastAPI middleware layer (#8672) --- vllm/entrypoints/openai/protocol.py | 5 +++ vllm/entrypoints/openai/serving_chat.py | 26 +++++++++++-- vllm/entrypoints/openai/serving_completion.py | 37 +++++++++++++++---- 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7e9f53b1816d1..40d27f984fbaa 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel): completion_tokens: Optional[int] = 0 +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + class JsonSchemaResponseFormat(OpenAIBaseModel): name: str description: Optional[str] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1ee4b3ce17cfa..0321ea98ec742 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -22,7 +22,8 @@ 
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) + DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, + ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, OpenAIServing, @@ -175,6 +176,11 @@ async def create_chat_completion( "--enable-auto-tool-choice and --tool-call-parser to be set") request_id = f"chat-{random_uuid()}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + try: guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -241,11 +247,13 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id, conversation, tokenizer) + request, result_generator, request_id, conversation, tokenizer, + request_metadata) try: return await self.chat_completion_full_generator( - request, result_generator, request_id, conversation, tokenizer) + request, result_generator, request_id, conversation, tokenizer, + request_metadata) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -262,6 +270,7 @@ async def chat_completion_stream_generator( request_id: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: model_name = self.base_model_paths[0].name created_time = int(time.time()) @@ -580,6 +589,13 @@ async def chat_completion_stream_generator( exclude_unset=True, exclude_none=True)) yield f"data: {final_usage_data}\n\n" + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens) + except ValueError as e: # TODO: Use a vllm-specific Validation Error logger.error("error in chat completion stream generator: %s", e) @@ -595,6 +611,7 @@ async def chat_completion_full_generator( request_id: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.base_model_paths[0].name @@ -714,6 +731,9 @@ async def chat_completion_full_generator( completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, ) + + request_metadata.final_usage_info = usage + response = ChatCompletionResponse( id=request_id, created=created_time, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9abd74d0561d0..0e8609002e39e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -18,7 +18,9 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - ErrorResponse, UsageInfo) + ErrorResponse, + RequestResponseMetadata, + UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, @@ -94,6 +96,10 @@ async def create_completion( request_id = f"cmpl-{random_uuid()}" created_time = 
int(time.time()) + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + # Schedule the request and get the result generator. generators: List[AsyncGenerator[RequestOutput, None]] = [] try: @@ -165,13 +171,15 @@ async def create_completion( # Streaming response if stream: - return self.completion_stream_generator(request, - result_generator, - request_id, - created_time, - model_name, - num_prompts=len(prompts), - tokenizer=tokenizer) + return self.completion_stream_generator( + request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(prompts), + tokenizer=tokenizer, + request_metadata=request_metadata) # Non-streaming response final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) @@ -198,6 +206,7 @@ async def create_completion( created_time, model_name, tokenizer, + request_metadata, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -227,6 +236,7 @@ async def completion_stream_generator( model_name: str, num_prompts: int, tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_text_lens = [0] * num_choices * num_prompts @@ -346,6 +356,14 @@ async def completion_stream_generator( exclude_unset=False, exclude_none=True)) yield f"data: {final_usage_data}\n\n" + # report to FastAPI middleware aggregate usage across all choices + total_prompt_tokens = sum(num_prompt_tokens) + total_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens) + except ValueError as e: # TODO: Use a vllm-specific Validation Error data = self.create_streaming_error_response(str(e)) @@ -360,6 +378,7 @@ def request_output_to_completion_response( created_time: int, model_name: str, tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> CompletionResponse: choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 @@ -433,6 +452,8 @@ def request_output_to_completion_response( total_tokens=num_prompt_tokens + num_generated_tokens, ) + request_metadata.final_usage_info = usage + return CompletionResponse( id=request_id, created=created_time, From 3368c3ab36436af1342a3156971412e9efdb6419 Mon Sep 17 00:00:00 2001 From: David Newman Date: Wed, 25 Sep 2024 17:52:26 +1000 Subject: [PATCH 31/50] [Bugfix] Ray 2.9.x doesn't expose available_resources_per_node (#8767) Signed-off-by: darthhexx --- vllm/executor/ray_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 59e9854393b6b..7e46acefc5b0e 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -18,9 +18,14 @@ try: import ray - from ray._private.state import available_resources_per_node from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup + try: + from ray._private.state import available_resources_per_node + except ImportError: + # Ray 2.9.x doesn't expose `available_resources_per_node` + from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be From 
8fae5ed7f6bfd63b81310fcb24b310d9205c9687 Mon Sep 17 00:00:00 2001 From: Woo-Yeon Lee Date: Wed, 25 Sep 2024 16:53:03 +0900 Subject: [PATCH 32/50] [Misc] Fix minor typo in scheduler (#8765) --- vllm/core/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b707d87c3af83..873decff37c1e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1554,14 +1554,14 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, # the number of new tokens that is dividable by the block size # to avoid partial block matching. block_size = self.cache_config.block_size - reminder = budget.token_budget % block_size - if reminder != 0: + remainder = budget.token_budget % block_size + if remainder != 0: raise ValueError("When enabling chunked prefill and " "prefix caching, max_num_batched_tokens " "(chunk size) must be dividable by " "block size, but got chunk_size " f"({budget.token_budget}) % block_size " - f"({block_size}) = {reminder}") + f"({block_size}) = {remainder}") if remaining_token_budget < num_new_tokens: num_new_tokens = (remaining_token_budget // block_size) * block_size From 1c046447a6d1ac3c99b9f453796f0d355d673deb Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:26:37 -0400 Subject: [PATCH 33/50] [CI/Build][Bugfix][Doc][ROCm] CI fix and doc update after ROCm 6.2 upgrade (#8777) --- .buildkite/test-pipeline.yaml | 5 ++++- Dockerfile.rocm | 2 +- docs/source/getting_started/amd-installation.rst | 12 +++++++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 379a67c4c8cf8..54dd87bfa2a10 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -90,8 +90,11 @@ steps: commands: - pip install -e ./plugins/vllm_add_dummy_model - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api] - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - pytest -v -s entrypoints/openai - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 9aa3a974e7046..496e6bed7c022 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -120,7 +120,7 @@ COPY . . 
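(Editorial aside on the RequestResponseMetadata plumbing from the serving_chat/serving_completion patches earlier in this series: those hunks stash per-request token usage on the incoming FastAPI request's state, raw_request.state.request_metadata, so that application-level middleware can report it. The patches themselves do not ship such a middleware; the following is only a sketch of how one might consume the field, with the middleware name and the logging format being illustrative rather than part of the series.)

    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.middleware("http")
    async def log_token_usage(request: Request, call_next):
        response = await call_next(request)
        # Set by the OpenAI-compatible chat/completion handlers; absent on
        # other routes, and for streaming responses it is only populated
        # once the generator has emitted its final usage information.
        metadata = getattr(request.state, "request_metadata", None)
        if metadata is not None and metadata.final_usage_info is not None:
            usage = metadata.final_usage_info
            print(f"{metadata.request_id}: {usage.prompt_tokens} prompt, "
                  f"{usage.completion_tokens} completion, "
                  f"{usage.total_tokens} total tokens")
        return response
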
# Package upgrades for useful functionality or to avoid dependency issues RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install --upgrade numba scipy huggingface-hub[cli] + python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard # Workaround for ray >= 2.10.0 diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 4ed0bfe70071d..301337aebcf4c 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -28,6 +28,16 @@ Option 1: Build from source with docker (recommended) You can build and install vLLM from source. First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image. +It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: + +.. code-block:: console + + { + "features": { + "buildkit": true + } + } + `Dockerfile.rocm `_ uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: @@ -152,7 +162,7 @@ Note to get your gfx architecture, run `rocminfo |grep gfx`. $ python3 setup.py develop - This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation + This may take 5-10 minutes. Currently, :code:`pip install .` does not work for ROCm installation. .. tip:: From 300da09177477d0a4d2b55790addefd971f52ae0 Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:35:52 -0400 Subject: [PATCH 34/50] [Kernel] Fullgraph and opcheck tests (#8479) --- .buildkite/test-pipeline.yaml | 19 +++- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 2 +- csrc/torch_bindings.cpp | 4 +- tests/compile/test_full_graph.py | 45 ++------ tests/compile/test_full_graph_multi_gpu.py | 22 ++++ tests/compile/test_full_graph_smoke.py | 13 +++ tests/compile/utils.py | 104 ++++++++++++++++++ tests/conftest.py | 6 + tests/kernels/test_aqlm.py | 37 +++++++ tests/kernels/test_attention.py | 9 +- tests/kernels/test_awq.py | 38 +++++++ tests/kernels/test_causal_conv1d.py | 74 ++++++++++++- tests/kernels/test_cutlass.py | 10 ++ tests/kernels/test_flash_attn.py | 61 +++++----- tests/kernels/test_fp8_quant.py | 29 +++++ tests/kernels/test_ggml.py | 22 ++++ tests/kernels/test_gptq.py | 29 +++++ tests/kernels/test_mamba_ssm.py | 66 +++++++++++ tests/kernels/test_marlin_gemm.py | 15 +++ tests/kernels/test_moe.py | 60 +++++++++- tests/kernels/test_rotary_embedding.py | 62 +++++++++++ tests/kernels/test_utils.py | 24 ++++ tests/kernels/utils.py | 43 +++++++- vllm/_custom_ops.py | 61 +++++----- .../layers/mamba/ops/mamba_ssm.py | 4 +- .../layers/quantization/gptq.py | 1 + 26 files changed, 744 insertions(+), 116 deletions(-) create mode 100644 tests/compile/test_full_graph_multi_gpu.py create mode 100644 tests/compile/test_full_graph_smoke.py create mode 100644 tests/compile/utils.py create mode 100644 tests/kernels/test_aqlm.py create mode 100644 tests/kernels/test_awq.py create mode 100644 tests/kernels/test_ggml.py create mode 100644 tests/kernels/test_gptq.py create mode 100644 tests/kernels/test_rotary_embedding.py create mode 100644 tests/kernels/test_utils.py diff --git 
a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 54dd87bfa2a10..ea8b3d46f1b3f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -70,7 +70,7 @@ steps: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - + - label: Core Test # 10min mirror_hardwares: [amd] fast_check: true @@ -210,6 +210,21 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 +- label: "PyTorch Fullgraph Smoke Test" + fast_check: true + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph_smoke.py + +- label: "PyTorch Fullgraph Test" + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph.py + - label: Kernels Test %N # 30min each mirror_hardwares: [amd] source_file_dependencies: @@ -355,7 +370,7 @@ steps: - tests/distributed/ - vllm/compilation commands: - - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_full_graph_multi_gpu.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index df968dda92adc..d7829f5d583d4 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { selective_scan_fwd_cuda(params, stream); }); - std::vector result = {out, x.value()}; + std::vector result = {out}; if (has_z) { result.push_back(out_z); } return result; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4b374af5ae24e..b6ba1b2a26e10 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor! A, Tensor! B, Tensor! C," "Tensor? D_, Tensor? z_, Tensor? delta_bias_," "bool delta_softplus," - "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]"); + "Tensor? index_, Tensor!? x) -> Tensor[]"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? bias_," "Tensor? seq_idx_," "Tensor? initial_states_," - "Tensor? final_states_out_," + "Tensor!? 
final_states_out_," "bool silu_activation) -> Tensor"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2e309aaa58d48..5dd65ad7236f9 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,42 +1,13 @@ -import os - import pytest -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -@pytest.mark.parametrize("tp_size", [1, 2]) -@fork_new_process_for_each_test -def test_full_graph(model, tp_size): - - # Skip the test if there are not enough CUDA devices. - if cuda_device_count_stateless() < tp_size: - pytest.skip("Not enough CUDA devices for the test.") - - # make sure these models can be captured in full graph mode - if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" +from vllm.compilation.backends import vllm_backend - from vllm import LLM, SamplingParams - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tp_size, - disable_custom_all_reduce=True) +from .utils import TEST_MODELS, check_full_graph_support - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +@pytest.mark.parametrize("model_info", TEST_MODELS) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +def test_full_graph(model_info, backend): + model = model_info[0] + model_kwargs = model_info[1] + check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py new file mode 100644 index 0000000000000..e9883d5254e72 --- /dev/null +++ b/tests/compile/test_full_graph_multi_gpu.py @@ -0,0 +1,22 @@ +import pytest + +from vllm.compilation.backends import vllm_backend +from vllm.utils import cuda_device_count_stateless + +from ..utils import fork_new_process_for_each_test +from .utils import TEST_MODELS_SMOKE, check_full_graph_support + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +@fork_new_process_for_each_test +def test_full_graph_multi_gpu(model_info, tp_size, backend): + model = model_info[0] + model_kwargs = model_info[1] + + # Skip the test if there are not enough CUDA devices. 
+ if cuda_device_count_stateless() < tp_size: + pytest.skip("Not enough CUDA devices for the test.") + + check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py new file mode 100644 index 0000000000000..0c5a95b4ead4c --- /dev/null +++ b/tests/compile/test_full_graph_smoke.py @@ -0,0 +1,13 @@ +import pytest + +from vllm.compilation.backends import vllm_backend + +from .utils import TEST_MODELS_SMOKE, check_full_graph_support + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +def test_full_graph(model_info, backend): + model = model_info[0] + model_kwargs = model_info[1] + check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/utils.py b/tests/compile/utils.py new file mode 100644 index 0000000000000..2d06a0946d911 --- /dev/null +++ b/tests/compile/utils.py @@ -0,0 +1,104 @@ +import os + +import torch + +from tests.quantization.utils import is_quant_method_supported +from vllm import LLM, SamplingParams +from vllm.plugins import set_torch_compile_backend +from vllm.utils import is_hip + +TEST_MODELS_SMOKE = [ + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { + "quantization": "compressed-tensors" + }), + ("meta-llama/Meta-Llama-3-8B", {}), +] + +TEST_MODELS = [ + ("facebook/opt-125m", {}), + ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { + "dtype": torch.float16, + "quantization": "compressed-tensors" + }), + ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", { + "dtype": torch.float16, + "quantization": "fp8" + }), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { + "quantization": "compressed-tensors" + }), + ("meta-llama/Meta-Llama-3-8B", {}), +] + +# TODO: enable in pytorch 2.5 +if False and is_quant_method_supported("aqlm"): # noqa: SIM223 + TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { + "quantization": "aqlm" + })) + +# TODO: enable in pytorch 2.5 +if False and is_quant_method_supported("gguf"): # noqa: SIM223 + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { + "quantization": "gguf" + })) + +if is_quant_method_supported("gptq"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { + "quantization": "gptq" + })) + +if is_quant_method_supported("gptq_marlin"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { + "quantization": "gptq_marlin" + })) + +if is_quant_method_supported("gptq_marlin_24"): + TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { + "quantization": "gptq_marlin_24" + })) + +if is_quant_method_supported("marlin"): + TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { + "quantization": "marlin" + })) + +if not is_hip() and is_quant_method_supported("awq"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { + "quantization": "AWQ" + })) + + +def check_full_graph_support(model, model_kwargs, backend, tp_size=1): + # make sure these models can be captured in full graph mode + if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: + os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" + + # Inductor doesn't support fp8/gptq_marlin_24 yet. 
+ quantization = model_kwargs.get("quantization") + if (quantization == "fp8" or quantization == "gptq_marlin" + or quantization == "gptq_marlin_24") and backend != "eager": + return + + set_torch_compile_backend(backend) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tp_size, + disable_custom_all_reduce=True, + **model_kwargs) + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/conftest.py b/tests/conftest.py index dcd9afdae3c14..354862e3579ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): cleanup() +@pytest.fixture(autouse=True) +def dynamo_reset(): + yield + torch._dynamo.reset() + + @pytest.fixture def example_prompts() -> List[str]: prompts = [] diff --git a/tests/kernels/test_aqlm.py b/tests/kernels/test_aqlm.py new file mode 100644 index 0000000000000..860fb66b17354 --- /dev/null +++ b/tests/kernels/test_aqlm.py @@ -0,0 +1,37 @@ +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_aqlm_dequant_opcheck(): + codes = torch.randint(-32768, + 32767, (22016, 512, 1), + device='cuda', + dtype=torch.int16) + codebooks = torch.rand((2, 65536, 1, 8), + device='cuda', + dtype=torch.float16) + codebook_partition_sizes = [11008, 11008] + + opcheck(torch.ops._C.aqlm_dequant, + (codes, codebooks, codebook_partition_sizes)) + + +def test_aqlm_gemm_opcheck(): + input = torch.rand((4, 4096), device='cuda', dtype=torch.float16) + codes = torch.randint(-32768, + 32767, (12288, 512, 1), + device='cuda', + dtype=torch.int16) + codebooks = torch.rand((3, 65536, 1, 8), + device='cuda', + dtype=torch.float16) + scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16) + codebook_partition_sizes = [4096, 4096, 4096] + bias = None + + opcheck(torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, None)) + opcheck(torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, bias)) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index ecab512cba16f..52f1ecd176963 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -205,7 +205,8 @@ def test_paged_attention( (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) @@ -246,7 +247,8 @@ def test_paged_attention( key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -274,7 +276,8 @@ def test_paged_attention( key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, 
max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_awq.py b/tests/kernels/test_awq.py new file mode 100644 index 0000000000000..e421aca48af2c --- /dev/null +++ b/tests/kernels/test_awq.py @@ -0,0 +1,38 @@ +import os + +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_awq_dequantize_opcheck(): + os.environ["VLLM_USE_TRITON_AWQ"] = "0" + qweight = torch.randint(-2000000000, + 2000000000, (8192, 256), + device='cuda', + dtype=torch.int32) + scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16) + zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32) + split_k_iters = 0 + thx = 0 + thy = 0 + opcheck(torch.ops._C.awq_dequantize, + (qweight, scales, zeros, split_k_iters, thx, thy)) + + +def test_awq_gemm_opcheck(): + os.environ["VLLM_USE_TRITON_AWQ"] = "0" + input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) + qweight = torch.randint(-2000000000, + 2000000000, (8192, 256), + device='cuda', + dtype=torch.int32) + scales = torch.randint(-2000000000, + 2000000000, (64, 256), + device='cuda', + dtype=torch.int32) + qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16) + split_k_iters = 8 + opcheck(torch.ops._C.awq_gemm, + (input, qweight, qzeros, scales, split_k_iters)) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 043c4923bd660..744e445fe6673 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -5,6 +5,8 @@ import torch.nn.functional as F from einops import rearrange +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.utils import seed_everything @@ -84,6 +86,64 @@ def causal_conv1d_update_ref(x: torch.Tensor, return (out if activation is None else F.silu(out)).to(dtype=dtype_in) +def causal_conv1d_opcheck_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert (initial_states is + None), "initial_states must be None if seq_idx is not None" + assert (not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and (initial_states.stride(2) != 1 + and initial_states.stride(1) != 1): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning 
final_states_out" + if final_states_out is not None: + assert (final_states_out.stride(2) == 1 + or final_states_out.stride(1) == 1) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty(batch, + width - 1, + dim, + device=x.device, + dtype=x.dtype).transpose(1, 2) + else: + final_states_out = None + + opcheck(torch.ops._C.causal_conv1d_fwd, + (x, weight, bias, seq_idx, initial_states, final_states_out, + activation in ["silu", "swish"])) + + @pytest.mark.parametrize("return_final_states", [False, True]) @pytest.mark.parametrize("has_initial_states", [False, True]) @pytest.mark.parametrize("channel_last", [False, True]) @@ -149,6 +209,14 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation) + + causal_conv1d_opcheck_fn(x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) + if return_final_states: assert final_states is not None and final_states_ref is not None assert torch.allclose(final_states, @@ -205,6 +273,10 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + opcheck( + torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation in ["silu", "swish"], None)) + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @@ -258,7 +330,5 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, bias, activation=activation) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index cc4ca2e91e76f..993e67e827ea0 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -15,6 +15,9 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +capability = current_platform.get_device_capability() +capability = capability[0] * 10 + capability[1] + def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) @@ -74,6 +77,9 @@ def cutlass_fp8_gemm_helper(m: int, torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) + opcheck(torch.ops._C.cutlass_scaled_mm, + (out, a, b, scale_a, scale_b, bias)) + def cutlass_int8_gemm_helper(m: int, n: int, @@ -425,3 +431,7 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): baseline = torch.mm(scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + +def test_cutlass_support_opcheck(): + opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 8e960d098c408..71f61c19dd951 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,6 +4,7 @@ import torch import vllm.attention.backends.flash_attn # noqa: F401 +from tests.kernels.utils import opcheck from vllm.utils import seed_everything NUM_HEADS = [(4, 4), (8, 2), (16, 2)] @@ -127,19 +128,19 @@ def test_flash_attn_with_paged_kv( else: test_utils = ["test_faketensor"] - 
torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache, - args=tuple(), - kwargs=dict( - decode_query=query.unsqueeze(1), - key_cache=key_cache, - value_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) + opcheck(torch.ops.vllm.flash_attn_with_kvcache, + args=tuple(), + kwargs=dict( + decode_query=query.unsqueeze(1), + key_cache=key_cache, + value_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) ref_output = ref_paged_attn( query=query, @@ -232,23 +233,23 @@ def test_varlen_with_paged_kv( else: test_utils = ["test_faketensor"] - torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func, - args=tuple(), - kwargs=dict( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) + opcheck(torch.ops.vllm.flash_attn_varlen_func, + args=tuple(), + kwargs=dict( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) ref_output = ref_paged_attn( query=query, diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 49f5ce53aab54..c18f5f468dc5a 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,6 +5,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) +from tests.kernels.utils import opcheck from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -16,6 +17,26 @@ SEEDS = [0] +def opcheck_fp8_quant(output, + input, + scale=None, + scale_ub=None, + use_per_token_if_dynamic=False): + if scale is not None: + opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale)) + elif use_per_token_if_dynamic: + scale = torch.empty((input.shape[0], 1), + device=input.device, + dtype=torch.float32) + opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant, + (output, input, scale, scale_ub)) + else: + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale)) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -41,6 +62,12 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, torch.testing.assert_close(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + opcheck_fp8_quant(ops_out, + x, + None, + scale_ub, + use_per_token_if_dynamic=True) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @@ -60,6 +87,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, torch.testing.assert_close(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + opcheck_fp8_quant(ops_out, x) 
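(Editorial aside: the new kernel tests in this patch all funnel through the opcheck helper in tests/kernels/utils.py, which wraps torch.library.opcheck to verify that each registered custom op is consistent with its schema and fake-tensor registration. As a self-contained illustration of the same pattern, below is a toy Python-defined op checked the same way. It is independent of vLLM's compiled extensions, assumes a PyTorch recent enough to provide torch.library.custom_op, and the editor_demo namespace is made up.)

    import torch
    from torch.library import custom_op, opcheck, register_fake

    # A toy custom op standing in for a compiled kernel.
    @custom_op("editor_demo::scale", mutates_args=())
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        return x * factor

    # The fake (meta) implementation only describes the output's shape and
    # dtype, which is what torch.compile and opcheck's fake-tensor tests need.
    @register_fake("editor_demo::scale")
    def _scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
        return torch.empty_like(x)

    # Runs the schema, fake-tensor and registration test suites against the
    # real implementation; raises if any of them disagree.
    opcheck(scale, (torch.randn(8), 2.0))
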
+ # Regression test for a case with large activations where an int32 index cannot # represent the number of elements. diff --git a/tests/kernels/test_ggml.py b/tests/kernels/test_ggml.py new file mode 100644 index 0000000000000..dddb285bf26ec --- /dev/null +++ b/tests/kernels/test_ggml.py @@ -0,0 +1,22 @@ +import gguf +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +@pytest.mark.parametrize("quant_type", [12]) +def test_ggml_opcheck(quant_type): + block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type] + shape = [256, 1152] + qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) + m = qweight.shape[0] + n = qweight.shape[1] // type_size * block_size + opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n)) + + x = torch.rand((m, 512), device='cuda', dtype=torch.float16) + opcheck(torch.ops._C.ggml_mul_mat_a8, + (qweight, x, quant_type, qweight.shape[0])) + opcheck(torch.ops._C.ggml_mul_mat_vec_a8, + (qweight, x, quant_type, qweight.shape[0])) diff --git a/tests/kernels/test_gptq.py b/tests/kernels/test_gptq.py new file mode 100644 index 0000000000000..c1ca6f1f5191b --- /dev/null +++ b/tests/kernels/test_gptq.py @@ -0,0 +1,29 @@ +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_gptq_shuffle_opcheck(): + weight = torch.randint(-2000000, + 2000000, (1792, 4096), + device='cuda', + dtype=torch.int32) + perm = torch.empty((0, ), device='cuda', dtype=torch.int32) + bit = 4 + opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit)) + + +def test_gptq_gemm_opcheck(): + a = torch.rand((240, 4096), device='cuda', dtype=torch.float16) + weight = torch.randint(-2000000, + 2000000, (512, 6144), + device='cuda', + dtype=torch.int32) + zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32) + scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16) + idx = torch.empty((0, ), device='cuda', dtype=torch.int32) + use_exllama = True + bit = 4 + opcheck(torch.ops._C.gptq_gemm, + (a, weight, zeros, scales, idx, use_exllama, bit)) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 366475222a68e..5a6149562e886 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -3,6 +3,8 @@ import torch.nn.functional as F from einops import rearrange, repeat +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.utils import seed_everything @@ -161,6 +163,59 @@ def selective_scan_ref(u, return out if not return_last_state else (out, last_state) +def selective_scan_opcheck_fn(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). 
+ """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) + n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) + x = torch.zeros(( + u.shape[0], + u.shape[1], + n_chunks, + int(A.shape[1] * 2), + ), + device=u.device, + dtype=torch.float32, + requires_grad=False) + x[:, :, 0, 0::2] = 1 + if prev_state is not None: + x[:, :, 0, 1::2].copy_(prev_state) + + # Disable test_autograd_registration for now as it seems to trigger + # a bogus error. + opcheck(torch.ops._C.selective_scan_fwd, + (u, delta, A, B, C, D, z, delta_bias, delta_softplus, + position_indices, x), + test_utils=["test_schema", "test_faketensor"]) + + @pytest.mark.parametrize('wtype', [torch.float32]) @pytest.mark.parametrize('itype', [torch.float32]) @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) @@ -274,6 +329,17 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, assert state is not None and state_ref is not None assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + selective_scan_opcheck_fn(u, + delta, + A, + B, + C, + D, + z=z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state) + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 721d3a6a819ac..a9bb72156c39e 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -501,3 +501,18 @@ def test_marlin_qqq_gemm( max_diff = compute_max_diff(output, output_ref) assert max_diff < 0.04 + + +def test_marlin_gemm_opcheck(): + size_m = 2048 + size_n = 4096 + size_k = 4096 + a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) + w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) + s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) + wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL).scratch + x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) + y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) + torch.testing.assert_close(x, y) + opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k)) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index b1f0516dfa0b3..c6ddcc8ce79f5 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -9,11 +9,14 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -247,6 +250,35 @@ def test_fused_marlin_moe( assert 
compute_max_diff(marlin_output, triton_output) < 4e-2 + if ops.supports_moe_ops: + token_expert_indicies = torch.empty(m, + topk, + dtype=torch.int32, + device=a.device) + + opcheck(torch.ops._moe_C.topk_softmax, ( + topk_weights, + topk_ids, + token_expert_indicies, + score.float(), + )) + + block_size_m = 4 + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, + e) + + max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + opcheck(torch.ops._moe_C.marlin_gemm_moe, + (a, qweight1, sorted_token_ids, topk_weights, topk_ids, + scales1, g_idx1, sort_indices1, workspace, quant_type, m, + 2 * n, k, True, e, topk, block_size_m, True, False)) + @pytest.mark.skip("This test is here for the sake of debugging, " "don't run it in automated tests.") @@ -319,3 +351,29 @@ def test_single_marlin_moe_multiply( torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 + + +def test_moe_align_block_size_opcheck(): + num_experts = 4 + block_size = 4 + topk_ids = torch.randint(0, + num_experts, (3, 4), + dtype=torch.int32, + device='cuda') + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + + opcheck(torch.ops._C.moe_align_block_size, + (topk_ids, num_experts, block_size, sorted_ids, expert_ids, + num_tokens_post_pad)) diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/test_rotary_embedding.py new file mode 100644 index 0000000000000..da879406b3936 --- /dev/null +++ b/tests/kernels/test_rotary_embedding.py @@ -0,0 +1,62 @@ +""" +Tests for miscellaneous utilities +""" + +from typing import Optional + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + + +def rotary_embedding_opcheck(rot, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None): + cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) + + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. 
+ if offsets is not None: + opcheck(torch.ops._C.batched_rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style, rot.rotary_dim, offsets)) + else: + opcheck(torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style)) + + +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_opcheck(dist_init, device, max_position, + is_neox_style, rotary_dim, head_size, + seq_len): + batch_size = 1 + base = 0 + num_heads = 7 + rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, torch.float32) + + positions = torch.randint(0, + max_position, (batch_size, seq_len), + device=device) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=torch.float32, + device=device) + key = torch.randn_like(query) + + rotary_embedding_opcheck(rot, positions, query, key) + offsets = torch.zeros(batch_size * seq_len, + device=device, + dtype=torch.long) + rotary_embedding_opcheck(rot, positions, query, key, offsets) diff --git a/tests/kernels/test_utils.py b/tests/kernels/test_utils.py new file mode 100644 index 0000000000000..7e5126a76f88b --- /dev/null +++ b/tests/kernels/test_utils.py @@ -0,0 +1,24 @@ +""" +Tests for miscellaneous utilities +""" + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.platforms import current_platform + + +def test_convert_fp8_opcheck(): + data = torch.randn((256, 256), dtype=torch.float32, device="cuda") + result = torch.empty_like(data, dtype=torch.float8_e4m3fn) + opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="Only supported for CUDA") +def test_cuda_utils_opcheck(): + opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) + opcheck( + torch.ops._C_cuda_utils. + get_max_shared_memory_per_block_device_attribute, (0, )) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5746932c30a45..08004efe9e2f8 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,12 +2,14 @@ import itertools import random +import unittest from numbers import Number from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union) import pytest import torch +from torch._prims_common import TensorLikeType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, @@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, output_under_test.view_as(ideal_output)) +# Copied/modified from torch._refs.__init__.py +def fp8_allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + torch._refs._check_close_args(name="torch.allclose", + a=a, + b=b, + rtol=rtol, + atol=atol) + + return bool( + torch.all( + torch.isclose(a.double(), + b.double(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan)).item()) + + +# A special version of op check that has a restricted default set of test_utils +# and a patched version of allclose that supports fp8 types. 
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, torch._library.custom_ops.CustomOpDef], args: Tuple[Any, ...], @@ -954,9 +984,10 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, raise_exception: bool = True, cond: bool = True) -> Dict[str, str]: - return torch.library.opcheck( - op, - args, - kwargs, - test_utils=test_utils, - raise_exception=raise_exception) if cond else {} + with unittest.mock.patch('torch.allclose', new=fp8_allclose): + return torch.library.opcheck( + op, + args, + kwargs, + test_utils=test_utils, + raise_exception=raise_exception) if cond else {} diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a71bafc974adf..4d71381184de5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -20,8 +20,10 @@ if current_platform.is_rocm(): import vllm._rocm_C # noqa: F401 +supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 + supports_moe_ops = True def hint_on_error(fn): @@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_g_idx, use_exllama, bit) -# TODO: has to be a better way to do this -try: - torch.ops._C.gptq_gemm # noqa B018 +if hasattr(torch.ops._C, "gptq_gemm"): @torch.library.register_fake("_C::gptq_gemm") def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -265,8 +265,6 @@ def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, return torch.empty((a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device) -except Exception: - pass def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, @@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) -# TODO: has to be a better way to do this -try: - torch.ops._C.gptq_marlin_24_gemm # noqa B018 +if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @torch.library.register_fake("_C::gptq_marlin_24_gemm") def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -420,8 +416,8 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @torch.library.register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, - b_q: torch. 
- Tensor, # Should be the tensor returned by machete_prepack_B + # Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, b_type: ScalarType, b_scales: Optional[torch.Tensor] = None, b_zeros: Optional[torch.Tensor] = None, @@ -451,10 +447,10 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, return torch.empty_like(x) @torch.library.register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, - bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: + def causal_conv1d_update_fake( + x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], silu_activation: bool, + conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.empty_like(x) @torch.library.register_fake("_C::selective_scan_fwd") @@ -465,20 +461,11 @@ def selective_scan_fwd_fake( delta_softplus: bool, index_: Optional[torch.Tensor], x: Optional[torch.Tensor]) -> List[torch.Tensor]: a = torch.empty_like(u) - if x is not None: - b = x - else: - b = torch.empty((u.size(0), u.size(1), A.size(1)), - dtype=u.dtype, - device=u.device) if z_ is not None: c = torch.empty_like(z_) - return [a, b, c] + return [a, c] else: - return [a, b] - -except Exception: - pass + return [a] # cutlass @@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor, return torch.ops._C.machete_prepack_B(b_q_weight, b_type) -# TODO: has to be a better way to do this -try: - torch.ops._C.permute_cols # noqa B018 +if hasattr(torch.ops._C, "permute_cols"): @torch.library.register_fake("_C::permute_cols") def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) -except Exception: - pass def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: @@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies, gating_output) +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @torch.library.register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + g_idx: torch.Tensor, perm: torch.Tensor, + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + def reshape_and_cache( key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index a0bed07ac6193..5fe451b2f1318 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -361,8 +361,8 @@ def selective_scan_fn(u, x[:, :, 0, 0::2] = 1 if prev_state is not None: x[:, :, 0, 1::2].copy_(prev_state) - out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, position_indices, x) + out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, position_indices, x) last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if z is None: return out if not return_last_state else (out, last_state) diff --git a/vllm/model_executor/layers/quantization/gptq.py 
b/vllm/model_executor/layers/quantization/gptq.py index c067a76405df6..1cfadb4f42ca8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -217,6 +217,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass From c6f2485c823b5cd76cca70798e653c6eadb811de Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Sep 2024 00:35:23 +0800 Subject: [PATCH 35/50] [[Misc]] Add extra deps for openai server image (#8792) --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index ec803764a128d..6bb4bd032c39c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -202,7 +202,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' + pip install accelerate hf_transfer 'modelscope!=1.15.0' bitsandbytes>=0.44.0 timm==0.9.10 ENV VLLM_USAGE_SOURCE production-docker-image From 0c4d2ad5e641de145682674066a84ffc632e714e Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Thu, 26 Sep 2024 00:35:53 +0800 Subject: [PATCH 36/50] [VLM][Bugfix] internvl with num_scheduler_steps > 1 (#8614) --- vllm/model_executor/models/internvl.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index fffd0d4161e10..b1748700d481a 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -19,7 +19,7 @@ from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -379,6 +379,11 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + if hasattr(self.language_model, "sampler"): + self.sampler = self.language_model.sampler + else: + self.sampler = Sampler() + def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale From 28e1299e60e565a56a2db41396380f74b8d29e57 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 26 Sep 2024 00:36:47 +0800 Subject: [PATCH 37/50] rename PromptInputs and inputs with backward compatibility (#8760) --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/async_engine/test_async_llm_engine.py | 8 +- tests/entrypoints/llm/test_encode.py | 34 ------ tests/entrypoints/llm/test_generate.py | 37 ------ tests/mq_llm_engine/test_error_handling.py | 12 +- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- 
vllm/engine/async_llm_engine.py | 110 +++++++++++++++--- vllm/engine/llm_engine.py | 52 +++++++-- vllm/engine/multiprocessing/__init__.py | 61 +++++++++- vllm/engine/multiprocessing/client.py | 95 ++++++++++++--- vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 68 +++++------ vllm/inputs/__init__.py | 20 +++- vllm/inputs/data.py | 48 +++++--- vllm/inputs/parse.py | 22 ++-- vllm/inputs/preprocess.py | 86 +++++++------- 21 files changed, 438 insertions(+), 245 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a39d1cf842f06..eadf994cacd34 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 241b2ccd0991e..e112b43aade5e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 9adf82d43f3e0..0d47281db485e 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptInputs +.. autodata:: vllm.inputs.PromptType .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 08db891665044..ca5b125369c85 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. 
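(Editorial aside: the documentation hunks above now point readers at vllm.inputs.PromptType, which covers plain strings, TextPrompt dicts and TokensPrompt dicts. A brief sketch of those prompt forms follows; the model name and token ids are purely illustrative and not taken from the patch.)

    from vllm import LLM, SamplingParams, TextPrompt, TokensPrompt

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0)

    outputs = llm.generate(
        [
            "Hello, my name is",                            # plain string
            TextPrompt(prompt="The capital of France is"),  # text prompt dict
            TokensPrompt(prompt_token_ids=[1, 2, 3]),       # pre-tokenized prompt
        ],
        params,
    )
    # Multimodal models additionally take multi_modal_data inside a TextPrompt,
    # e.g. TextPrompt(prompt=..., multi_modal_data={"image": pil_image}).
    for out in outputs:
        print(out.outputs[0].text)
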
Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 6cae76f74603d..1903a7582dc89 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): + params = SamplingParams() + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 - await engine.add_request("1", "", None) + await engine.add_request("1", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 1 assert engine.engine.step_calls == 1 - await engine.add_request("2", "", None) + await engine.add_request("2", "", params) engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) @@ -111,7 +113,7 @@ async def test_new_requests_event(): await asyncio.sleep(0.001) assert engine.engine.step_calls == old_step_calls - await engine.add_request("3", "", None) + await engine.add_request("3", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index d1056a0490509..1885f2e168d80 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput], assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) - - v2_output = llm.encode(prompt, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) - - v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode( - [{ - "prompt": p - } for p in PROMPTS], - pooling_params=pooling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): pooling_params = 
PoolingParams() diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index cd989225e2483..6543c4bb1b58e 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=prompt, - sampling_params=sampling_params) - - v2_output = llm.generate(prompt, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate({"prompt": prompt}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=PROMPTS, - sampling_params=sampling_params) - - v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate( - [{ - "prompt": p - } for p in PROMPTS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 76b2f494d5b25..616a15a1328de 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. 
with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - inputs="Hello my name is", + prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass @@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index e27fd77923412..3ffa126070ca0 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - inputs="Hello my name is Robert and", + prompt="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 90363b3e49b73..8f477ea84756d 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version_tuple__", "LLM", "ModelRegistry", - "PromptInputs", + "PromptType", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 34e7e05341f02..54c5af2fe3665 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import time import weakref from functools import partial -from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, - Mapping, Optional, Set, Tuple, Type, Union) +from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, + List, Mapping, Optional, Set, Tuple, Type, Union, overload) from weakref import ReferenceType import vllm.envs as envs @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger 
import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -28,7 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import weak_bind +from vllm.utils import deprecate_kwargs, weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -402,17 +402,54 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() + @overload # DEPRECATED async def add_request_async( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload + async def add_request_async( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request_async( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Async version of :meth:`add_request`.""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -420,7 +457,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -774,16 +811,55 @@ async def run_engine_loop(engine_ref: ReferenceType): # This method does not need to be async, but kept that way # for backwards compatibility. - async def add_request( + @overload # DEPRECATED + def add_request( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... 
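For reference, both spellings of the renamed keyword are accepted during the deprecation window; the old one now emits a DeprecationWarning. A minimal caller-side sketch (the engine is assumed to be constructed elsewhere; request ids and prompt text are placeholders):

    from vllm import SamplingParams
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def submit(engine: AsyncLLMEngine) -> None:
        params = SamplingParams(max_tokens=8)
        # Deprecated keyword: still accepted, but warns.
        await engine.add_request("req-0", inputs="Hello", params=params)
        # Preferred spelling after the rename.
        await engine.add_request("req-1", "Hello", params)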
+ + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -797,7 +873,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - inputs=inputs, + prompt=prompt, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +884,7 @@ async def add_request( async def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,8 +898,7 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -881,7 +956,7 @@ async def generate( """ async for output in await self.add_request( request_id, - inputs, + prompt, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -891,7 +966,7 @@ async def generate( async def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -904,8 +979,7 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. 
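The streaming entry points renamed above keep their behavior: generate still yields RequestOutput snapshots until the request finishes. A self-contained sketch of the new call shape (the model name is only a placeholder):

    import asyncio

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def main() -> None:
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))
        final = None
        # Each iteration yields the latest RequestOutput for this request.
        async for output in engine.generate("Hello, my name is",
                                            SamplingParams(max_tokens=16),
                                            request_id="demo-0"):
            final = output
        if final is not None:
            print(final.outputs[0].text)

    if __name__ == "__main__":
        asyncio.run(main())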
@@ -959,7 +1033,7 @@ async def encode( """ async for output in await self.add_request( request_id, - inputs, + prompt, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c341b236003a3..7266d8e18a8ab 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union +from typing import Set, Type, Union, overload import torch from typing_extensions import TypeVar @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs) + InputRegistry, LLMInputs, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -51,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device, weak_bind +from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -689,16 +689,51 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() + @overload # DEPRECATED def add_request( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + ) -> None: + ... + + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Add a request to the engine's request pool. @@ -708,8 +743,7 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. @@ -744,6 +778,10 @@ def add_request( >>> # continue the request processing >>> ... 
""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -756,7 +794,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 165e6cc2146c3..05067a6a192d5 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,13 +1,14 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Mapping, Optional, Union +from typing import List, Mapping, Optional, Union, overload from vllm import PoolingParams -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.utils import deprecate_kwargs VLLM_RPC_SUCCESS_STR = "SUCCESS" @@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - inputs: PromptInputs + prompt: PromptType params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None + @overload # DEPRECATED + def __init__( + self, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... 
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def __init__( + self, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + if inputs is not None: + prompt = inputs + assert (prompt is not None and params is not None + and request_id is not None) + + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + @dataclass class RPCError: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7e397cf408fba..239ca52ef13e2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union) + Union, overload) import cloudpickle import zmq @@ -24,13 +24,14 @@ RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import deprecate_kwargs logger = init_logger(__name__) @@ -366,14 +367,45 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return ENGINE_DEAD_ERROR(self._errored_with) + @overload # DEPRECATED def generate( self, - inputs: PromptInputs, + *, + inputs: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @overload + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def generate( + self, + prompt: Optional[PromptType] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. @@ -382,8 +414,7 @@ def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. 
See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -392,17 +423,51 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(inputs, sampling_params, request_id, + if inputs is not None: + prompt = inputs + assert (prompt is not None and sampling_params is not None + and request_id is not None) + + return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) + @overload # DEPRECATED def encode( self, - inputs: PromptInputs, + *, + inputs: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @overload + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def encode( + self, + prompt: Optional[PromptType] = None, + pooling_params: Optional[PoolingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model. @@ -411,8 +476,7 @@ def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -423,12 +487,17 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. 
""" - return self._process_request(inputs, pooling_params, request_id, + if inputs is not None: + prompt = inputs + assert (prompt is not None and pooling_params is not None + and request_id is not None) + + return self._process_request(prompt, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -461,7 +530,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - inputs=inputs, + prompt=prompt, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index b1dd9915cbbf5..b406d4a759667 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -271,7 +271,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - inputs=request.inputs, + prompt=request.prompt, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 70444faa670a2..d0bbeb357b506 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptInputs +from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generates outputs for a request""" + """Generate outputs for a request.""" ... def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 77ae7b088398a..f4943cb38da44 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -12,7 +12,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -293,8 +293,8 @@ def generate( @overload def generate( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -304,14 +304,13 @@ def generate( ... 
@deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -330,7 +329,9 @@ def generate( into a single list and pass it to this method. Args: - inputs: A list of inputs to generate completions for. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -358,12 +359,13 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -378,7 +380,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -648,8 +650,8 @@ def encode( @overload def encode( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -659,14 +661,13 @@ def encode( ... @deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -682,9 +683,9 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptInputs` - for more details about the format of each input. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. 
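The embedding path mirrors this: encode takes the same prompts argument plus PoolingParams. A minimal sketch, assuming an embedding model such as the one named below is available:

    from vllm import LLM, PoolingParams

    llm = LLM(model="intfloat/e5-mistral-7b-instruct")
    outputs = llm.encode(["Hello world", "vLLM is fast"], PoolingParams())
    for out in outputs:
        # Each EmbeddingRequestOutput carries one embedding vector.
        print(len(out.outputs.embedding))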
@@ -707,19 +708,20 @@ def encode( ) if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -763,9 +765,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - inputs: List[PromptInputs] = [] + parsed_prompts: List[PromptType] = [] for i in range(num_requests): - item: PromptInputs + item: PromptType if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -774,13 +776,13 @@ def _convert_v1_inputs( else: raise AssertionError - inputs.append(item) + parsed_prompts.append(item) - return inputs + return parsed_prompts def _validate_and_add_requests( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[PromptType, Sequence[PromptType]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], @@ -788,11 +790,11 @@ def _validate_and_add_requests( guided_options: Optional[GuidedDecodingRequest] = None, priority: Optional[List[int]] = None, ) -> None: - if isinstance(inputs, (str, dict)): + if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - inputs = [inputs] + prompts = [prompts] - num_requests = len(inputs) + num_requests = len(prompts) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -809,9 +811,9 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
- for i, request_inputs in enumerate(inputs): + for i, prompt in enumerate(prompts): self._add_request( - request_inputs, + prompt, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, @@ -821,7 +823,7 @@ def _validate_and_add_requests( def _add_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -830,7 +832,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - inputs, + prompt, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0b08e9691f915..a8c8672cb5fe7 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptInputs", - "SingletonPromptInputs", + "PromptType", + "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", @@ -28,3 +28,17 @@ "InputContext", "InputRegistry", ] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155b..9e6238cb85ac0 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] +SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptInputs` may be employed +A prompt of type :class:`SingletonPrompt` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,33 +55,33 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - """Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a - decoder prompt. + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. 
The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptInputs` schemas, and are not + :class:`SingletonPrompt` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, - and that the `encoder_prompt` and `decoder_prompt` + and that the :code:`encoder_prompt` and :code:`decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptInputs` instances. + :class:`SingletonPrompt` instances. """ encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) -_T2 = TypeVar("_T2", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) def build_explicit_enc_dec_prompt( @@ -176,3 +172,17 @@ def to_enc_dec_tuple_list( return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index ac9d355c64c80..e5fa1e4184277 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(inputs, str): - return ParsedStrPrompt(type="str", content=inputs) - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if "prompt_token_ids" in prompt: return ParsedTokensPrompt(type="tokens", - content=inputs) # type: ignore - elif "prompt" in inputs: - return ParsedTextPrompt(type="text", content=inputs) + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(inputs, dict) and "encoder_prompt" in inputs + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt def 
is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index be2aa5f8cb7d0..1f1b048d37e9b 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, + SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -206,7 +206,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -216,7 +216,7 @@ def _extract_prompt_components( Arguments: * request_id - * inputs: single encoder or decoder input prompt + * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -226,24 +226,24 @@ def _extract_prompt_components( * multi_modal_data ''' - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -251,33 +251,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -285,7 +285,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -311,7 +311,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> 
EncoderDecoderLLMInputs: ''' @@ -339,7 +339,7 @@ def _process_encoder_decoder_prompt( Arguments: - * inputs: an input prompt + * prompt: an input prompt * request_id Returns: @@ -350,13 +350,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -365,7 +365,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, ) @@ -375,20 +375,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -401,7 +401,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, ) @@ -425,7 +425,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -436,7 +436,7 @@ def _process_decoder_only_prompt( Arguments: - * inputs: input prompt + * prompt: input prompt * request_id * lora_request * prompt_adapter_request @@ -447,7 +447,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -459,14 +459,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -478,7 +478,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -488,17 +488,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if 
is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -506,7 +506,7 @@ def preprocess( async def preprocess_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -516,17 +516,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From 64840dfae48621c5c2004eb8f1cb7fba49f9b24e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A7=91=E8=8B=B1?= Date: Thu, 26 Sep 2024 00:37:41 +0800 Subject: [PATCH 38/50] [Frontend] MQLLMEngine supports profiling. (#8761) --- vllm/engine/multiprocessing/__init__.py | 8 +++++++- vllm/engine/multiprocessing/client.py | 23 ++++++++++++++++++----- vllm/engine/multiprocessing/engine.py | 21 ++++++++++++++++++++- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 05067a6a192d5..6d6d7895b2101 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -107,7 +107,13 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] +class RPCUProfileRequest(Enum): + START_PROFILE = 1 + STOP_PROFILE = 2 + + +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, + RPCUProfileRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 239ca52ef13e2..700e65000e052 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -21,7 +21,8 @@ IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCProcessRequest, - RPCStartupRequest, RPCStartupResponse) + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptType @@ -38,10 +39,10 @@ class MQClientClosedError(Exception): """Exception class raised when the client is used post-close. - + The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which causes a ZMQError and creates a huge stack trace. So, we throw this error such that we can suppress it. 
""" @@ -345,7 +346,7 @@ async def do_log_stats(self): async def check_health(self): """ The check health loop probes the health status of the - Engine's health every N seconds and sets _errored_with + Engine's health every N seconds and sets _errored_with if the engine is unhealthy. """ if self._errored_with is not None: @@ -561,3 +562,15 @@ async def _process_request( await self.abort(request_id) finally: self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index b406d4a759667..eecca82cd2f7d 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -18,9 +18,11 @@ IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCProcessRequest, - RPCStartupRequest, RPCStartupResponse) + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -249,6 +251,11 @@ def handle_new_input(self): self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() else: raise ValueError("Unknown RPCRequest Type: " f"{type(request)}") @@ -356,6 +363,18 @@ def _set_errored(self, e: BaseException): def _alive(self): self._last_alive_time = time.time() + def start_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.start_profile() + else: + self.engine.model_executor._run_workers("start_profile") + + def stop_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.stop_profile() + else: + self.engine.model_executor._run_workers("stop_profile") + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): From 873edda6cf8a2902e8b08eea0bf8f8f6d73704a8 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 25 Sep 2024 12:43:36 -0400 Subject: [PATCH 39/50] [Misc] Support FP8 MoE for compressed-tensors (#8588) --- tests/weight_loading/models-large.txt | 1 + vllm/model_executor/layers/fused_moe/layer.py | 9 +- .../compressed_tensors/compressed_tensors.py | 2 +- .../compressed_tensors_moe.py | 218 +++++++++++++++++- vllm/model_executor/models/phimoe.py | 4 +- 5 files changed, 226 insertions(+), 8 deletions(-) diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 2f5c6c5a117f3..3e6eba04f1a87 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -1,4 +1,5 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main +compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, 
main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f6c6f5f529408..bce740d0db750 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -323,10 +323,12 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: - # compressed-tensors represents weights on disk which are flipped + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO (mgoin): check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality loaded_weight = loaded_weight.t().contiguous() if ( self.quant_method.__class__.__name__ - == "CompressedTensorsMoEMethod") else loaded_weight + == "CompressedTensorsWNA16MoEMethod") else loaded_weight if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " @@ -353,6 +355,9 @@ def weight_loader(self, param: torch.nn.Parameter, # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: + # this is needed for compressed-tensors only + loaded_weight = loaded_weight.to(param.data.device) + if param.data[expert_id] != 1 and (param.data[expert_id] - loaded_weight).abs() > 1e-5: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e536fae45c845..362feeef2e33c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -73,7 +73,7 @@ def get_quant_method( if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod(self) + return CompressedTensorsMoEMethod.get_moe_method(self) return None @classmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7dee2fca81153..6666a4bf1f26a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -5,12 +5,16 @@ import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - CompressionFormat) + CompressionFormat, QuantizationStrategy) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_hip, print_warning_once class GPTQMarlinState(Enum): @@ -18,11 +22,219 @@ class GPTQMarlinState(Enum): READY = enum.auto() -__all__ = ["CompressedTensorsMoEMethod"] +__all__ = [ + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsWNA16MoEMethod" +] class 
CompressedTensorsMoEMethod(FusedMoEMethodBase): + @staticmethod + def get_moe_method( + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ) -> "CompressedTensorsMoEMethod": + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + weight_quant = quant_config.target_scheme_map["Linear"].get("weights") + input_quant = quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + return CompressedTensorsWNA16MoEMethod(quant_config) + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") + + +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales" + "for weights and activations are supported. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # If rocm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
+ assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + def __init__( self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index a3555a294bb66..487d9fc2f4337 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -321,13 +321,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=True, - quant_config=None, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=True, - quant_config=None, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, From 4f1ba0844b83b4e7d0ff1672b7ba502ce8732f95 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 25 Sep 2024 10:36:26 -0700 Subject: [PATCH 40/50] Revert "rename PromptInputs and inputs with backward compatibility (#8760) (#8810) --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/async_engine/test_async_llm_engine.py | 8 +- tests/entrypoints/llm/test_encode.py | 34 ++++++ tests/entrypoints/llm/test_generate.py | 37 ++++++ tests/mq_llm_engine/test_error_handling.py | 12 +- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- vllm/engine/async_llm_engine.py | 110 +++--------------- vllm/engine/llm_engine.py | 52 ++------- vllm/engine/multiprocessing/__init__.py | 61 +--------- vllm/engine/multiprocessing/client.py | 95 +++------------ vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 68 ++++++----- vllm/inputs/__init__.py | 20 +--- vllm/inputs/data.py | 48 +++----- vllm/inputs/parse.py | 22 ++-- vllm/inputs/preprocess.py | 86 +++++++------- 21 files changed, 245 insertions(+), 438 deletions(-) diff --git 
a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index eadf994cacd34..a39d1cf842f06 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompts: List[PromptType] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index e112b43aade5e..241b2ccd0991e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 0d47281db485e..9adf82d43f3e0 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptType +.. autodata:: vllm.inputs.PromptInputs .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index ca5b125369c85..08db891665044 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. 
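The dict-based prompt schemas referenced in the docs above are passed straight to LLM.generate; the snippet below is only a minimal sketch of that call pattern under the post-revert API (the model name, prompt strings, and token ids are illustrative and not part of this patch):

    from vllm import LLM, SamplingParams

    # Illustrative model choice; any small generative model works for this sketch.
    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0.0, max_tokens=16)

    outputs = llm.generate(
        [
            "Hello, my name is",                     # plain string prompt
            {"prompt": "The capital of France is"},  # TextPrompt dict
            {"prompt_token_ids": [1, 2, 3, 4]},      # TokensPrompt dict
        ],
        sampling_params=params,
    )
    for output in outputs:
        print(output.outputs[0].text)
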
diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 1903a7582dc89..6cae76f74603d 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -86,19 +86,17 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): - params = SamplingParams() - engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 - await engine.add_request("1", "", params) + await engine.add_request("1", "", None) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 1 assert engine.engine.step_calls == 1 - await engine.add_request("2", "", params) + await engine.add_request("2", "", None) engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) @@ -113,7 +111,7 @@ async def test_new_requests_event(): await asyncio.sleep(0.001) assert engine.engine.step_calls == old_step_calls - await engine.add_request("3", "", params) + await engine.add_request("3", "", None) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index 1885f2e168d80..d1056a0490509 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -49,6 +49,21 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput], assert [o.outputs for o in o1] == [o.outputs for o in o2] +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) + + v2_output = llm.encode(prompt, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -64,6 +79,25 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) + + v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode( + [{ + "prompt": p + } for p in PROMPTS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): pooling_params = PoolingParams() diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 6543c4bb1b58e..cd989225e2483 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -47,6 +47,23 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def 
test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=prompt, + sampling_params=sampling_params) + + v2_output = llm.generate(prompt, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate({"prompt": prompt}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -62,6 +79,26 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=PROMPTS, + sampling_params=sampling_params) + + v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate( + [{ + "prompt": p + } for p in PROMPTS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 616a15a1328de..76b2f494d5b25 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - prompt="Hello my name is", + inputs="Hello my name is", sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass @@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. 
- async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 3ffa126070ca0..e27fd77923412 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - prompt="Hello my name is Robert and", + inputs="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 8f477ea84756d..90363b3e49b73 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version_tuple__", "LLM", "ModelRegistry", - "PromptType", + "PromptInputs", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 54c5af2fe3665..34e7e05341f02 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import time import weakref from functools import partial -from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, - List, Mapping, Optional, Set, Tuple, Type, Union, overload) +from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, Type, Union) from weakref import ReferenceType import vllm.envs as envs @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -28,7 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import deprecate_kwargs, weak_bind +from vllm.utils import weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -402,54 +402,17 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() - @overload # DEPRECATED async def add_request_async( self, request_id: str, - *, - inputs: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - ... 
- - @overload - async def add_request_async( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - ... - - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - async def add_request_async( - self, - request_id: str, - prompt: Optional[PromptType] = None, - params: Optional[Union[SamplingParams, PoolingParams]] = None, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - *, - inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Async version of :meth:`add_request`.""" - if inputs is not None: - prompt = inputs - assert prompt is not None and params is not None - if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -457,7 +420,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -811,55 +774,16 @@ async def run_engine_loop(engine_ref: ReferenceType): # This method does not need to be async, but kept that way # for backwards compatibility. - @overload # DEPRECATED - def add_request( - self, - request_id: str, - *, - inputs: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: - ... - - @overload - def add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: - ... 
- - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) async def add_request( self, request_id: str, - prompt: Optional[PromptType] = None, - params: Optional[Union[SamplingParams, PoolingParams]] = None, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - *, - inputs: Optional[PromptType] = None, # DEPRECATED + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: - if inputs is not None: - prompt = inputs - assert prompt is not None and params is not None - if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -873,7 +797,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - prompt=prompt, + inputs=inputs, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -884,7 +808,7 @@ async def add_request( async def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -898,7 +822,8 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -956,7 +881,7 @@ async def generate( """ async for output in await self.add_request( request_id, - prompt, + inputs, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -966,7 +891,7 @@ async def generate( async def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -979,7 +904,8 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. 
@@ -1033,7 +959,7 @@ async def encode( """ async for output in await self.add_request( request_id, - prompt, + inputs, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7266d8e18a8ab..c341b236003a3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union, overload +from typing import Set, Type, Union import torch from typing_extensions import TypeVar @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptType) + InputRegistry, LLMInputs, PromptInputs) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -51,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind +from vllm.utils import Counter, Device, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -689,51 +689,16 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - @overload # DEPRECATED def add_request( self, request_id: str, - *, - inputs: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> None: - ... - - @overload - def add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - ... - - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - def add_request( - self, - request_id: str, - prompt: Optional[PromptType] = None, - params: Optional[Union[SamplingParams, PoolingParams]] = None, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - *, - inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Add a request to the engine's request pool. @@ -743,7 +708,8 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. @@ -778,10 +744,6 @@ def add_request( >>> # continue the request processing >>> ... 
""" - if inputs is not None: - prompt = inputs - assert prompt is not None and params is not None - if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -794,7 +756,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 6d6d7895b2101..1603189979a2c 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,14 +1,13 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Mapping, Optional, Union, overload +from typing import List, Mapping, Optional, Union from vllm import PoolingParams -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.utils import deprecate_kwargs VLLM_RPC_SUCCESS_STR = "SUCCESS" @@ -24,67 +23,13 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - prompt: PromptType + inputs: PromptInputs params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None - @overload # DEPRECATED - def __init__( - self, - *, - inputs: PromptType, - params: Union[SamplingParams, PoolingParams], - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - ... - - @overload - def __init__( - self, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - ... 
- - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - def __init__( - self, - prompt: Optional[PromptType] = None, - params: Optional[Union[SamplingParams, PoolingParams]] = None, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: - if inputs is not None: - prompt = inputs - assert (prompt is not None and params is not None - and request_id is not None) - - super().__init__() - - self.prompt = prompt - self.params = params - self.request_id = request_id - self.lora_request = lora_request - self.trace_headers = trace_headers - self.prompt_adapter_request = prompt_adapter_request - @dataclass class RPCError: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 700e65000e052..0ee56f7bf8407 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union, overload) + Union) import cloudpickle import zmq @@ -25,14 +25,13 @@ RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import deprecate_kwargs logger = init_logger(__name__) @@ -368,45 +367,14 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return ENGINE_DEAD_ERROR(self._errored_with) - @overload # DEPRECATED def generate( self, - *, - inputs: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> AsyncGenerator[RequestOutput, None]: - ... - - @overload - def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> AsyncGenerator[RequestOutput, None]: - ... - - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - def generate( - self, - prompt: Optional[PromptType] = None, - sampling_params: Optional[SamplingParams] = None, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - *, - inputs: Optional[PromptType] = None # DEPRECATED + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. @@ -415,7 +383,8 @@ def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. 
See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -424,51 +393,17 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - if inputs is not None: - prompt = inputs - assert (prompt is not None and sampling_params is not None - and request_id is not None) - - return self._process_request(prompt, sampling_params, request_id, + return self._process_request(inputs, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) - @overload # DEPRECATED def encode( self, - *, - inputs: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: - ... - - @overload - def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: - ... - - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - def encode( - self, - prompt: Optional[PromptType] = None, - pooling_params: Optional[PoolingParams] = None, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - *, - inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model. @@ -477,7 +412,8 @@ def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -488,17 +424,12 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. 
""" - if inputs is not None: - prompt = inputs - assert (prompt is not None and pooling_params is not None - and request_id is not None) - - return self._process_request(prompt, pooling_params, request_id, + return self._process_request(inputs, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -531,7 +462,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - prompt=prompt, + inputs=inputs, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index eecca82cd2f7d..1b2e7ccf8664f 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -278,7 +278,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - prompt=request.prompt, + inputs=request.inputs, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d0bbeb357b506..70444faa670a2 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptInputs from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request.""" + """Generates outputs for a request""" ... def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f4943cb38da44..77ae7b088398a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -12,7 +12,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -293,8 +293,8 @@ def generate( @overload def generate( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -304,13 +304,14 @@ def generate( ... 
@deprecate_kwargs( + "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", + additional_message="Please use the 'inputs' parameter instead.", ) def generate( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -329,9 +330,7 @@ def generate( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -359,13 +358,12 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -380,7 +378,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -650,8 +648,8 @@ def encode( @overload def encode( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -661,13 +659,14 @@ def encode( ... @deprecate_kwargs( + "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", + additional_message="Please use the 'inputs' parameter instead.", ) def encode( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -683,9 +682,9 @@ def encode( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. 
@@ -708,20 +707,19 @@ def encode( ) if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -765,9 +763,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - parsed_prompts: List[PromptType] = [] + inputs: List[PromptInputs] = [] for i in range(num_requests): - item: PromptType + item: PromptInputs if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -776,13 +774,13 @@ def _convert_v1_inputs( else: raise AssertionError - parsed_prompts.append(item) + inputs.append(item) - return parsed_prompts + return inputs def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], @@ -790,11 +788,11 @@ def _validate_and_add_requests( guided_options: Optional[GuidedDecodingRequest] = None, priority: Optional[List[int]] = None, ) -> None: - if isinstance(prompts, (str, dict)): + if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] + inputs = [inputs] - num_requests = len(prompts) + num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -811,9 +809,9 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
- for i, prompt in enumerate(prompts): + for i, request_inputs in enumerate(inputs): self._add_request( - prompt, + request_inputs, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, @@ -823,7 +821,7 @@ def _validate_and_add_requests( def _add_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -832,7 +830,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - prompt, + inputs, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index a8c8672cb5fe7..0b08e9691f915 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptType", - "SingletonPrompt", + "PromptInputs", + "SingletonPromptInputs", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", @@ -28,17 +28,3 @@ "InputContext", "InputRegistry", ] - - -def __getattr__(name: str): - if name == "PromptInput": - import warnings - - msg = ("PromptInput has been renamed to PromptType. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PromptType - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 9e6238cb85ac0..75ab0c770155b 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPrompt` may be employed +A prompt of type :class:`SingletonPromptInputs` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,33 +55,33 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - """ - Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a decoder prompt. + """Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a + decoder prompt. 
The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPrompt` schemas, and are not + :class:`SingletonPromptInputs` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, - and that the :code:`encoder_prompt` and :code:`decoder_prompt` + and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPrompt` instances. + :class:`SingletonPromptInputs` instances. """ encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -140,8 +140,12 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) -_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) +_T1 = TypeVar("_T1", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) +_T2 = TypeVar("_T2", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) def build_explicit_enc_dec_prompt( @@ -172,17 +176,3 @@ def to_enc_dec_tuple_list( return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] - - -def __getattr__(name: str): - if name == "PromptInput": - import warnings - - msg = ("PromptInput has been renamed to PromptType. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PromptType - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e5fa1e4184277..ac9d355c64c80 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(prompt, str): - return ParsedStrPrompt(type="str", content=prompt) - elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: + if isinstance(inputs, str): + return ParsedStrPrompt(type="str", content=inputs) + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore - elif "prompt" in prompt: - return ParsedTextPrompt(type="text", content=prompt) + content=inputs) # type: ignore + elif "prompt" in inputs: + return ParsedTextPrompt(type="text", content=inputs) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(prompt, dict) and "encoder_prompt" in prompt + inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(inputs, dict) and "encoder_prompt" in inputs def 
is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 1f1b048d37e9b..be2aa5f8cb7d0 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, - SingletonPrompt) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -206,7 +206,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -216,7 +216,7 @@ def _extract_prompt_components( Arguments: * request_id - * prompt: single encoder or decoder input prompt + * inputs: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -226,24 +226,24 @@ def _extract_prompt_components( * multi_modal_data ''' - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -251,33 +251,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -285,7 +285,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -311,7 +311,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> 
EncoderDecoderLLMInputs: ''' @@ -339,7 +339,7 @@ def _process_encoder_decoder_prompt( Arguments: - * prompt: an input prompt + * inputs: an input prompt * request_id Returns: @@ -350,13 +350,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_comps = self._extract_prompt_components( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -365,7 +365,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, ) @@ -375,20 +375,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_task = self._extract_prompt_components_async( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -401,7 +401,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, ) @@ -425,7 +425,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -436,7 +436,7 @@ def _process_decoder_only_prompt( Arguments: - * prompt: input prompt + * inputs: input prompt * request_id * lora_request * prompt_adapter_request @@ -447,7 +447,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -459,14 +459,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -478,7 +478,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -488,17 +488,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if 
is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -506,7 +506,7 @@ def preprocess( async def preprocess_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -516,17 +516,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From 770ec6024fc00cd696899f5c6fdc53b7148876e6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 13:29:32 -0700 Subject: [PATCH 41/50] [Model] Add support for the multi-modal Llama 3.2 model (#8811) Co-authored-by: simon-mo Co-authored-by: Chang Su Co-authored-by: Simon Mo Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: Roger Wang --- docs/source/models/supported_models.rst | 5 + examples/offline_inference_vision_language.py | 24 + examples/openai_vision_api_client.py | 4 +- requirements-common.txt | 2 +- .../vision_language/__init__.py | 0 .../vision_language/test_mllama.py | 283 ++++ vllm/config.py | 4 +- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/chat_utils.py | 28 +- vllm/entrypoints/openai/serving_chat.py | 2 + vllm/inputs/data.py | 6 + vllm/inputs/preprocess.py | 22 +- vllm/inputs/registry.py | 54 +- vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/mllama.py | 1135 +++++++++++++++++ vllm/multimodal/base.py | 6 + vllm/multimodal/image.py | 5 + vllm/sequence.py | 12 +- vllm/transformers_utils/config.py | 17 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/mllama.py | 28 + vllm/transformers_utils/tokenizer.py | 1 - vllm/worker/enc_dec_model_runner.py | 40 +- vllm/worker/utils.py | 4 - 24 files changed, 1647 insertions(+), 45 deletions(-) create mode 100644 tests/models/encoder_decoder/vision_language/__init__.py create mode 100644 tests/models/encoder_decoder/vision_language/test_mllama.py create mode 100644 vllm/model_executor/models/mllama.py create mode 100644 vllm/transformers_utils/configs/mllama.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d86d0860f7f29..bf690726a637b 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -254,6 +254,11 @@ Multimodal Language Models - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - + * - :code:`MllamaForConditionalGeneration` + - Llama 3.2 + - Image + - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. 
+ - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - Image\ :sup:`E` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 6675aa0109a68..6d34621a8a9bc 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -242,6 +242,29 @@ def run_qwen2_vl(question, modality): return llm, prompt, stop_token_ids +# LLama +def run_mllama(question, modality): + assert modality == "image" + + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + # Note: The default setting of max_num_seqs (256) and + # max_model_len (131072) for this model may cause OOM. + # You may lower either to run this example on lower-end GPUs. + + # The configuration below has been confirmed to launch on a + # single H100 GPU. + llm = LLM( + model=model_name, + max_num_seqs=16, + enforce_eager=True, + ) + + prompt = f"<|image|><|begin_of_text|>{question}" + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -256,6 +279,7 @@ def run_qwen2_vl(question, modality): "internvl_chat": run_internvl, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "mllama": run_mllama, } diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 1ba702ef019e4..71ae03e4d148b 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -38,7 +38,7 @@ "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "What's in this image?" }, { "type": "image_url", @@ -75,7 +75,7 @@ def encode_image_base64_from_url(image_url: str) -> str: "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "What's in this image?" }, { "type": "image_url", diff --git a/requirements-common.txt b/requirements-common.txt index c113ff3630425..2fc89c026901b 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -4,7 +4,7 @@ numpy < 2.0.0 requests tqdm py-cpuinfo -transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfox. +transformers >= 4.45.0 # Required for Llama 3.2. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. 
fastapi < 0.113.0; python_version < '3.9' diff --git a/tests/models/encoder_decoder/vision_language/__init__.py b/tests/models/encoder_decoder/vision_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py new file mode 100644 index 0000000000000..cda0926d0baf9 --- /dev/null +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -0,0 +1,283 @@ +from typing import List, Optional, Tuple, Type, overload + +import pytest +from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, + BatchEncoding) + +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs + +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ....utils import multi_gpu_test +from ...utils import check_logprobs_close + +_LIMIT_IMAGE_PER_PROMPT = 1 + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|image|><|begin_of_text|>The meaning of the image is", + "cherry_blossom": + "<|image|><|begin_of_text|>The city is", +}) + +text_only_prompts = [ + "The color of the sky is blue but sometimes it can also be", +] + +models = [ + "meta-llama/Llama-3.2-11B-Vision-Instruct", +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], + model: str): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + + hf_output_ids = [ + token_id for idx, token_id in enumerate(output_ids) + if token_id != image_token_id or output_ids[idx - 1] != image_token_id + ] + + assert output_str[0] == " " + hf_output_str = output_str[1:] + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... 
+ + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + images = [asset.pil_image for asset in image_assets] + + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [ + prompt if size is not None else text_only_prompts[0] + for size in sizes + ], + [ + image.resize(size) if size is not None else None + for size in sizes + ], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + if len(sizes) == 0: + inputs_per_image.append( + (text_only_prompts, [None] * len(text_only_prompts))) + else: + raise ValueError("You must provide either `size_factors` or `sizes`") + + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test are from IMAGE_ASSETS. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + dtype=dtype, + max_num_seqs=16, + max_model_len=4096, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + def process(hf_inputs: BatchEncoding): + return hf_inputs + + from transformers import AutoConfig + from transformers.models.mllama import MllamaConfig as MllamaConfigHf + + # use transformer's MllamaConfig for hf_runner + # and vllm's MllamaConfig for vllm_runner + AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True) + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + auto_cls=AutoModelForVision2Seq) as hf_model: + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + from vllm.transformers_utils.configs.mllama import MllamaConfig + AutoConfig.register("mllama", MllamaConfig, exist_ok=True) + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [ + # Text only + [], + # Single-size + [(512, 512)], + # Single-size, batched + [(512, 512), (512, 512), (512, 512)], + # Multi-size, batched + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028)], + # Multi-size, batched, including text only + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], + # mllama has 8 possible aspect ratios, carefully set the sizes + # to cover all of them + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype, + max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [ + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes, + dtype, max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + ) diff --git a/vllm/config.py b/vllm/config.py index 308f29a3dc371..108badf150c86 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -576,7 +576,9 @@ def get_multimodal_config(self) -> "MultiModalConfig": @property def is_encoder_decoder_model(self) -> bool: """Extract the HF 
encoder/decoder model flag.""" - return getattr(self.hf_config, "is_encoder_decoder", False) + return getattr(self.hf_config, "is_encoder_decoder", False) or ( + (hasattr(self.hf_config, "text_config") and getattr( + self.hf_config.text_config, "is_encoder_decoder", False))) @property def is_embedding_model(self) -> bool: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c341b236003a3..768ac69c3692d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1734,7 +1734,11 @@ def is_embedding_model(self): def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): - if self.is_encoder_decoder_model(): + if self.model_config.is_multimodal_model: + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + prompt_ids = inputs.get("prompt_token_ids") + elif self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") else: prompt_ids = inputs.get("prompt_token_ids") diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f1ce2c36fcceb..4a575ae8f8537 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -159,6 +159,8 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if model_type in ("chameleon", "internvl_chat"): return "" + if model_type == "mllama": + return "<|image|>" if model_type == "qwen2_vl": return "<|vision_start|><|image_pad|><|vision_end|>" @@ -358,6 +360,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} def _parse_chat_message_content_parts( @@ -368,7 +371,11 @@ def _parse_chat_message_content_parts( texts: List[str] = [] mm_parser = mm_tracker.create_parser() + keep_multimodal_content = \ + mm_tracker._model_config.hf_config.model_type in \ + MODEL_KEEP_MULTI_MODAL_CONTENT + has_image = False for part in parts: part_type = part["type"] if part_type == "text": @@ -383,6 +390,7 @@ def _parse_chat_message_content_parts( "will be ignored.") mm_parser.parse_image(image_url["url"]) + has_image = True elif part_type == "audio_url": audio_url = _AudioParser(part)["audio_url"] @@ -394,12 +402,20 @@ def _parse_chat_message_content_parts( raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, - text_prompt) - - return [ConversationMessage(role=role, content=text_prompt)] + if keep_multimodal_content: + text_prompt = "\n".join(texts) + role_content = [{'type': 'text', 'text': text_prompt}] + + if has_image: + role_content = [{'type': 'image'}] + role_content + return [ConversationMessage(role=role, + content=role_content)] # type: ignore + else: + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_counts, text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] # No need to validate using Pydantic again diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0321ea98ec742..94076ea3a51db 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ 
b/vllm/entrypoints/openai/serving_chat.py @@ -309,6 +309,8 @@ async def chat_completion_stream_generator( async for res in result_generator: if res.prompt_token_ids is not None: num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155b..a71e9a7b5db66 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs): available. """ + encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the encoder model, + if the model supports it. + """ + _T1 = TypeVar("_T1", bound=SingletonPromptInputs, diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index be2aa5f8cb7d0..bee3d1ed75cbb 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -128,6 +128,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: def _prepare_decoder_input_ids_for_generation( self, decoder_input_ids: Optional[List[int]], + force_bos: bool = True, ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. @@ -157,8 +158,8 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if force_bos and (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -295,18 +296,25 @@ def _build_enc_dec_llm_inputs( encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal encoder-decoder models are " - "not supported yet") + if decoder_mm_data is not None: + raise ValueError( + "Multi-modality decoder inputs of encoder-decoder models are " + "not supported yet") - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + # For Multi-Modal models (e.g., mllama), the text input can be + # <|image|><|begin_of_text|>hello world. And we should not add + # another <|begin_of_text|> to the beginning. 
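        # [Editor's illustration, not lines from this diff] For mllama the
        # tokenized decoder prompt already begins with the image token followed
        # by <|begin_of_text|> (roughly [128256, 128000, ...], with 128256 being
        # the MLLAMA_IMAGE_TOKEN_ID defined later in this patch), so
        # force_bos=False leaves it untouched; a purely textual request still
        # gets decoder_start_token_id prepended as before.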
+ decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( + decoder_prompt_ids, + force_bos=(encoder_mm_data is None and decoder_mm_data is None))) return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, + multi_modal_data=decoder_mm_data, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, + encoder_multi_modal_data=encoder_mm_data, ) def _process_encoder_decoder_prompt( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 6ab23d1c4b769..159d958ebf671 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -112,6 +112,8 @@ class InputRegistry: def __init__(self) -> None: self._dummy_factories_by_model_type: Dict[Type[nn.Module], DummyDataFactory] = {} + self._dummy_encoder_factories_by_model_type: Dict[ + Type[nn.Module], DummyDataFactory] = {} self._input_processors_by_model_type: Dict[Type[nn.Module], InputProcessor] = {} @@ -162,11 +164,44 @@ def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): return self._dummy_factories_by_model_type \ .get(model_cls, self._default_dummy_data_factory) + def register_dummy_encoder_data(self, factory: DummyDataFactory): + """ + Register a dummy encoder data factory to a model class + + This is similar to :meth:`~register_dummy_data`, but for encoder input. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_encoder_factories_by_model_type: + logger.warning( + "Model class %s already has dummy encoder data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_encoder_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]): + if model_cls in self._dummy_encoder_factories_by_model_type: + dummy_factory = self._dummy_encoder_factories_by_model_type[ + model_cls] + else: + logger.warning( + "No dummy encoder data factory registered to %s. " + "Using the dummy data factory for the model instead.", + model_cls) + dummy_factory = self._get_dummy_data_factory(model_cls) + return dummy_factory + def dummy_data_for_profiling( self, model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", + is_encoder_data: bool = False, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data for profiling the memory usage of a model. 
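# [Editor's sketch, not part of this diff] How a model opts in to the new
# encoder-side dummy data: register_dummy_encoder_data is applied as a class
# decorator, mirroring the mllama registration that appears later in this
# patch. The factory names and trivial token contents below are assumptions
# for illustration; only the (ctx, seq_len, mm_counts) signature and the
# (SequenceData, multi_modal_data) return value follow this file.
from array import array

from torch import nn

from vllm.inputs import INPUT_REGISTRY
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData


def dummy_decoder_data(ctx, seq_len, mm_counts):
    # Decoder-side profiling data: seq_len placeholder token ids, no mm data.
    return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len), None


def dummy_encoder_data(ctx, seq_len, mm_counts):
    # Encoder-side profiling data; it may legitimately come out shorter than
    # seq_len, which the hunk below downgrades from an assertion to a warning.
    return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len), None


@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data)
class MyEncDecModel(nn.Module):  # hypothetical model class
    ...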
@@ -184,8 +219,10 @@ def dummy_data_for_profiling( from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - dummy_factory = self._get_dummy_data_factory(model_cls) - + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) mm_processor_kwargs = get_allowed_kwarg_only_overrides( dummy_factory, overrides=model_config.mm_processor_kwargs) @@ -196,10 +233,15 @@ def dummy_data_for_profiling( # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids - assert len(num_tokens) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") - + if len(num_tokens) < seq_len: + if is_encoder_data: + logger.warning( + "Expected at least %d dummy encoder tokens for profiling, " + "but found %d tokens instead.", seq_len, len(num_tokens)) + else: + raise AssertionError( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(num_tokens)} tokens instead.") if mm_data is not None: for k, v in mm_data.items(): num_items = len(v) if isinstance(v, list) else 1 diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 3f52eb44edfff..3a6fa9e26ff4b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -101,6 +101,8 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), "UltravoxModel": ("ultravox", "UltravoxModel"), + "MllamaForConditionalGeneration": ("mllama", + "MllamaForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py new file mode 100644 index 0000000000000..aa868a3b8da28 --- /dev/null +++ b/vllm/model_executor/models/mllama.py @@ -0,0 +1,1135 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Mllama model.""" +import math +from array import array +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers.models.mllama.configuration_mllama as config_mllama +from PIL import Image +from torch import nn +from transformers.modeling_outputs import (BaseModelOutput, + CausalLMOutputWithPast) +from transformers.models.mllama.image_processing_mllama import ( + get_optimal_tiled_canvas) + +import vllm.distributed.parallel_state as ps +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + +from .clip import CLIPMLP +from .interfaces import SupportsMultiModal +from .llama import LlamaDecoderLayer, LlamaMLP + +logger = init_logger(__name__) +MLLAMA_IMAGE_TOKEN_ID = 128256 +MLLAMA_IMAGE_TOKEN = "<|image|>" + + +class MllamaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: """ + """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" + aspect_ratio_ids: torch.Tensor + """Shape: `(batch_size, max_num_image)`""" + aspect_ratio_mask: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_tiles)`""" + + +# TODO: support LlamaImageEmbeddingInputs + + +def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): + # move encoder_prompt to prompt + if llm_inputs.get("prompt") is None: + llm_inputs["prompt"] = llm_inputs["encoder_prompt"] + llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + + # process multi-modal data + assert "decoder_multi_modal_data" not in llm_inputs, \ + "multi-modal data should be put in encoder message of mllama" + multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + + if multi_modal_data is None or "image" not in multi_modal_data \ + or multi_modal_data["image"] is None: + # text-only + llm_inputs["encoder_prompt"] = "" + llm_inputs["encoder_prompt_token_ids"] = [] + llm_inputs["encoder_multi_modal_data"] = {} + return llm_inputs + + # get num_tiles + if isinstance(multi_modal_data['image'], Image.Image): + multi_modal_data['image'] = [multi_modal_data['image']] + hf_config = ctx.model_config.hf_config + num_tiles = 0 + for image in multi_modal_data["image"]: + width, height = image.size + tile_size = hf_config.vision_config.image_size + canvas_height, canvas_width = get_optimal_tiled_canvas( + image_height=height, + image_width=width, + max_image_tiles=hf_config.vision_config.max_num_tiles, + 
tile_size=tile_size, + ) + num_tiles_height = canvas_height // tile_size + num_tiles_width = canvas_width // tile_size + num_tiles += num_tiles_height * num_tiles_width + + # set encoder prompt based on num_tiles + assert hf_config.vision_config.image_size % 14 == 0, \ + "chunk size should be multiple of 14" + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + num_tokens = num_tiles * token_per_chunk + llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID + ] * num_tokens + + return llm_inputs + + +def get_max_mllama_image_tokens(ctx: InputContext) -> int: + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + return hf_config.vision_config.max_num_tiles * token_per_chunk + + +def dummy_decoder_seq_data(seq_len: int, num_images: int): + # <|image|> * num_images + 0 * (seq_len - num_images) + assert seq_len >= num_images, \ + "seq_len should be greater than or equal to num_images" + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) + return SequenceData(token_ids) + + +def dummy_encoder_seq_data(ctx: InputContext, num_images: int): + num_tokens = get_max_mllama_image_tokens(ctx) * num_images + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens + return SequenceData(token_ids) + + +def dummy_image(num_images: int, ): + width = height = 1024 + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_decoder_seq_data(seq_len, num_images), None + + +def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x, _ = self._linear(x) + return x + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = self.gate.tanh( + ) * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +# TODO: support other attention backends for attention in vision model +class MllamaVisionSdpaAttention(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + + model_parallel_size = get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // 
model_parallel_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attention_mask, + dropout_p=0.0) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(attn_output.shape[0], + attn_output.shape[1], -1) + output, _ = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = False): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention(config) + self.mlp = CLIPMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + num_layers=32, + is_gated=False, + output_hidden_states=None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, is_gated) + for _ in range(num_layers) + ]) + self.output_hidden_states = output_hidden_states or [] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + encoder_states = () + + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + + if len(self.layers) - 1 in self.output_hidden_states: + encoder_states = encoder_states + 
(hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = ColumnParallelConv2dPatch( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config) + + self.pre_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.post_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices) + self.global_transformer = MllamaVisionEncoder(config, + config.num_global_layers, + is_gated=True) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward(self, pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, \ + height, width = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) + + # patch embedding + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype)) + hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, 
padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state, intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, + dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + hidden_state = self.global_transformer( + hidden_state, attention_mask=attention_mask)[0] + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + return hidden_state + + +class MllamaTextRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Optional[config_mllama.MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.config = config + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.num_heads = self.config.num_attention_heads + self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_key_value_heads = self.config.num_key_value_heads + self.num_local_key_value_heads = \ + self.num_key_value_heads // self.model_parallel_size + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + 
self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim + + # TODO: change to Q/KV separate linear after #7448 is merged + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = Attention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv_dec, _ = self.qkv_proj(hidden_states) + q, _, _ = qkv_dec.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + if cross_attention_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(cross_attention_states) + _, k, v = qkv_enc.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + q = q.view(-1, self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + out, _ = self.o_proj(output) + return out + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" + + def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \ + -> None: + super().__init__() + self.layer_idx = layer_idx + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_idx=layer_idx, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + kv_cache: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh( + ) * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = full_text_row_masked_out_mask * 
hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh( + ) * hidden_states + return hidden_states + + +class MllamaTextModel(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "model" + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, + config.hidden_size) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_idx in range(config.num_hidden_layers): + if layer_idx in self.cross_attention_layers: + layers.append( + MllamaCrossAttentionDecoderLayer(config, layer_idx)) + else: + # TODO: force LlamaDecoderLayer to config.attention_bias=False + layers.append( + LlamaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config)) + + self.layers = nn.ModuleList(layers) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): + if not skip_cross_attention: + hidden_states = decoder_layer( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask= + full_text_row_masked_out_mask, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + elif isinstance(decoder_layer, LlamaDecoderLayer): + hidden_states, residual = decoder_layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + residual=None, + ) + hidden_states = hidden_states + residual + else: + raise ValueError( + f"Unknown decoder layer type {type(decoder_layer)}") + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MllamaForCausalLM(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "language_model" + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" + ] + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.vocab_size = config.vocab_size + self.model = MllamaTextModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> torch.Tensor: + 
hidden_states = self.model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama) +@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) +@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) +class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, + config: config_mllama.MllamaConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.vocab_size = config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = \ + config.pad_token_id if config.pad_token_id is not None else -1 + self.image_size = config.vision_config.image_size + + self.vision_model = MllamaVisionModel(config.vision_config) + self.language_model = MllamaForCausalLM( + config.text_config, + cache_config=cache_config, + quant_config=quant_config, + ) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.logits_processor = LogitsProcessor(config.output_hidden_states, + config.text_config.vocab_size) + self.sampler = Sampler() + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.language_model.lm_head, + hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def _parse_and_validate_image_input(self, **kwargs: object): + # tensor with the same shape will be batched together by + # MultiModalInputs.batch, so pixel_values here can be: + # - List[List[torch.Tensor]]: + # with shape (num_tiles, 3, image_res, image_res) + # - List[torch.Tensor]: + # with shape (num_image, num_tiles, 3, image_res, image_res) + # - torch.Tensor: + # with shape (bs, num_image, num_tiles, 3, image_res, image_res) + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_ids", None) + aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_mask", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + assert aspect_ratio_ids is not None + assert 
aspect_ratio_mask is not None + max_num_images = max([len(x[0]) for x in pixel_values]) + if max_num_images == 0: + raise ValueError("No images provided.") + max_num_tiles = max( + max([len(x) for x in y[0]]) for y in pixel_values) + device = self.multi_modal_projector.weight.device + bsz = len(pixel_values) + out_num_tiles = [] + out_images = torch.zeros( + bsz, + max_num_images, + max_num_tiles, + 3, + self.image_size, + self.image_size, + dtype=torch.float32, + device=device, + ) + out_ar_ids = torch.ones(bsz, + max_num_images, + dtype=torch.int64, + device=device) + out_ar_mask = torch.zeros(bsz, + max_num_images, + max_num_tiles, + dtype=torch.int64, + device=device) + for b in range(len(pixel_values)): + _num_tiles = [] + for i in range(len(pixel_values[b][0])): + img = pixel_values[b][0][i] + out_images[b, i, :img.shape[0]] = img + out_ar_ids[b, i] = aspect_ratio_ids[b][0][i] + out_ar_mask[b, i] = aspect_ratio_mask[b][0][i] + _num_tiles.append(img.shape[0]) + out_num_tiles.append(_num_tiles) + + return MllamaImagePixelInputs( + type="pixel_values", + data=out_images, + aspect_ratio_ids=out_ar_ids, + aspect_ratio_mask=out_ar_mask, + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def flat_encoder_result(self, cross_attention_states: torch.Tensor, + attn_metadata: AttentionMetadata): + + cross_attention_states_flat = torch.zeros( + sum(attn_metadata.encoder_seq_lens), + cross_attention_states.shape[-1], + device=cross_attention_states.device, + dtype=cross_attention_states.dtype) + start_pos = 0 + for seq_len, vision_token_in_batch in zip( + attn_metadata.encoder_seq_lens, cross_attention_states): + end_pos = start_pos + seq_len + cross_attention_states_flat[ + start_pos:end_pos] = vision_token_in_batch[:seq_len] + start_pos = end_pos + cross_attention_states = cross_attention_states_flat + + full_text_row_masked_out_mask = torch.ones( + (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) + start_pos = 0 + for seq_len, encoder_seq_len in zip( + attn_metadata.seq_lens_tensor.cpu(), + attn_metadata.encoder_seq_lens): + if encoder_seq_len == 0: + full_text_row_masked_out_mask[start_pos:start_pos + + seq_len] = False + start_pos += seq_len + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + cross_attention_states.device) + + return cross_attention_states, full_text_row_masked_out_mask + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> Union[Tuple, CausalLMOutputWithPast]: + if attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: + raise ValueError("Chunk prefill not supported") + image_inputs = self._parse_and_validate_image_input(**kwargs) + if image_inputs is None: + cross_attention_mask = None + full_text_row_masked_out_mask = ( + attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to( + input_ids.device) + cross_attention_states = None + skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 + else: + # NOTE: llama's reference implementation runs vision model on CPU + pixel_values = image_inputs['data'] + aspect_ratio_ids = image_inputs['aspect_ratio_ids'] + aspect_ratio_mask = image_inputs['aspect_ratio_mask'] + cross_attention_states = self.vision_model(pixel_values, + aspect_ratio_ids, + aspect_ratio_mask) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) + + bsz, _, _, _, 
image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + + cross_attention_states, full_text_row_masked_out_mask = \ + self.flat_encoder_result(cross_attention_states, attn_metadata) + skip_cross_attention = False + # TODO: support multi-image by this mask + cross_attention_mask = None + + outputs = self.language_model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, + ) + + return outputs + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params = set() + for name, loaded_weight in weights: + if 'patch_embedding.weight' in name: + name = name.replace('patch_embedding.weight', + 'patch_embedding._linear.weight') + loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.pop(name) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 87d3a4576f332..8bcb38ef241ed 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -54,6 +54,12 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: if isinstance(nested_tensors, torch.Tensor): return nested_tensors + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) + stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] if not is_list_of(stacked, torch.Tensor, check="all"): # Only tensors (not lists) can be stacked. 
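The comments in `_parse_and_validate_image_input` above describe how variable-tile image inputs are padded into one dense batch tensor of shape `(bsz, num_image, num_tiles, 3, image_res, image_res)`. Below is a minimal, self-contained sketch of that padding scheme in plain PyTorch; the tile counts and image size are hypothetical and this is not the vLLM implementation itself.

```python
import torch


def pad_image_tiles(pixel_values, image_size=560):
    # pixel_values: list (batch) of lists (images per request) of
    # (num_tiles, 3, image_size, image_size) tensors.
    bsz = len(pixel_values)
    max_num_images = max(len(imgs) for imgs in pixel_values)
    max_num_tiles = max(img.shape[0] for imgs in pixel_values for img in imgs)
    out = torch.zeros(bsz, max_num_images, max_num_tiles, 3,
                      image_size, image_size)
    num_tiles = []
    for b, imgs in enumerate(pixel_values):
        per_request = []
        for i, img in enumerate(imgs):
            out[b, i, :img.shape[0]] = img      # copy the real tiles
            per_request.append(img.shape[0])    # remember the true tile count
        num_tiles.append(per_request)
    return out, num_tiles


# Two requests, one image each: a 4-tile image and a 2-tile image.
batch = [[torch.randn(4, 3, 560, 560)], [torch.randn(2, 3, 560, 560)]]
padded, tiles = pad_image_tiles(batch)
print(padded.shape, tiles)  # torch.Size([2, 1, 4, 3, 560, 560]) [[4], [2]]
```

Rows beyond a request's true tile count stay zero, which is why the real per-image tile counts are tracked separately.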
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 31b1c3f93411a..d3a230e40477e 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -2,6 +2,7 @@ import torch from PIL import Image +from transformers.image_processing_base import BatchFeature from vllm.config import ModelConfig from vllm.inputs.registry import InputContext @@ -39,6 +40,10 @@ def _default_input_mapper( ) -> MultiModalInputs: model_config = ctx.model_config + # Processed by input processor + if isinstance(data, BatchFeature): + return MultiModalInputs(data.data) + # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) diff --git a/vllm/sequence.py b/vllm/sequence.py index fda7ef87749a1..49a198df045bd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -13,6 +13,7 @@ import msgspec import torch +from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams @@ -21,7 +22,6 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: - from vllm.inputs import LLMInputs from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -471,7 +471,15 @@ def prompt_token_ids(self) -> List[int]: @property def multi_modal_data(self) -> "MultiModalDataDict": - return self.inputs.get("multi_modal_data") or {} + if self.inputs.get("multi_modal_data") and self.inputs.get( + "encoder_multi_modal_data"): + raise ValueError( + "Multi-modal data in both encoder and decoder is not supported." + ) + inputs = self.inputs + return self.inputs.get("multi_modal_data") or (cast( + EncoderDecoderLLMInputs, + inputs).get("encoder_multi_modal_data")) or {} @property def lora_int_id(self) -> int: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1744935d624fb..3871c0cb8b819 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -22,9 +22,10 @@ EAGLEConfig, ExaoneConfig, GraniteConfig, InternVLChatConfig, JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, RWConfig, - SolarConfig, UltravoxConfig) + MllamaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + RWConfig, SolarConfig, + UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -37,6 +38,10 @@ logger = init_logger(__name__) +_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { + "mllama": MllamaConfig +} + _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, @@ -55,11 +60,15 @@ # Granite can be removed from here once we have upgraded to # transformers 4.45+ "granite": GraniteConfig, + **_CONFIG_REGISTRY_OVERRIDE_HF } for name, cls in _CONFIG_REGISTRY.items(): with contextlib.suppress(ValueError): - AutoConfig.register(name, cls) + if name in _CONFIG_REGISTRY_OVERRIDE_HF: + AutoConfig.register(name, cls, exist_ok=True) + else: + AutoConfig.register(name, cls) class ConfigFormat(str, enum.Enum): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index ea4fc8ad21f35..d5b13adb58a0b 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import 
JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig @@ -26,6 +27,7 @@ "MedusaConfig", "EAGLEConfig", "ExaoneConfig", + "MllamaConfig", "MLPSpeculatorConfig", "NemotronConfig", "SolarConfig", diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py new file mode 100644 index 0000000000000..49e766d7fa1f4 --- /dev/null +++ b/vllm/transformers_utils/configs/mllama.py @@ -0,0 +1,28 @@ +from transformers.models.mllama import configuration_mllama as mllama_hf_config + + +class MllamaTextConfig(mllama_hf_config.MllamaTextConfig): + ''' + Use this class to override is_encoder_decoder: + - transformers regards mllama as is_encoder_decoder=False + - vllm needs is_encoder_decoder=True to enable cross-attention + ''' + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + self.is_encoder_decoder = True + + +class MllamaConfig(mllama_hf_config.MllamaConfig): + + def __init__( + self, + text_config=None, + **kwargs, + ): + if isinstance(text_config, dict): + text_config = MllamaTextConfig(**text_config) + super().__init__(text_config=text_config, **kwargs) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index f9fb8d1e103b7..2a2d74382e37a 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -111,7 +111,6 @@ def get_tokenizer( 'encoding and decoding.', FutureWarning, stacklevel=2) - if tokenizer_mode == "mistral": tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), revision=revision) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 709efdc8b9d57..bd716ac3e7ec3 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,7 +18,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, + MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) @@ -52,6 +53,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -194,6 +196,8 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -202,6 +206,8 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), **seqlen_agnostic_kwargs) logits = 
self.model.compute_logits(hidden_or_intermediate_states, @@ -288,8 +294,7 @@ def profile_run(self) -> None: max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( self.model_config) if max_mm_tokens > 0: - raise NotImplementedError( - "Multi-modal encoder-decoder models are not supported yet") + logger.info("Starting profile run for multi-modal models.") batch_size = 0 for group_id in range(max_num_seqs): @@ -297,24 +302,39 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, _ = self.input_registry \ - .dummy_data_for_profiling(self.model_config, + decoder_seq_data, decoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, seq_len, - self.mm_registry) + self.mm_registry, + is_encoder_data=False) + encoder_seq_data, encoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) # Having more tokens is over-conservative but otherwise fine - assert len(seq_data.prompt_token_ids) >= seq_len, ( + assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(seq_data.prompt_token_ids)}") + f"but got: {len(decoder_seq_data.prompt_token_ids)}") + + assert decoder_dummy_multi_modal_data is None or \ + encoder_dummy_multi_modal_data is None, ( + "Multi-modal data can't be provided in both encoder and decoder" + ) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: decoder_seq_data}, sampling_params=sampling_params, block_tables=None, - encoder_seq_data=seq_data, + encoder_seq_data=encoder_seq_data, cross_block_table=None, + multi_modal_data=decoder_dummy_multi_modal_data + or encoder_dummy_multi_modal_data, ) seqs.append(seq) diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index a58b80e4f2adb..a07395dfc61d8 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -39,10 +39,6 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - if enc_dec_mr.model_config.is_multimodal_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) - if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) From e2c6e0a8291126c868b669f631837c7781646fdc Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:29:48 -0700 Subject: [PATCH 42/50] [Doc] Update doc for Transformers 4.45 (#8817) --- docs/source/models/supported_models.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index bf690726a637b..c807617a2c10d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -242,12 +242,12 @@ Multimodal Language Models * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - Video - - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note) + - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - Image\ :sup:`+` / Video - - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. 
(see note) + - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - * - :code:`MiniCPMV` - MiniCPM-V @@ -298,7 +298,7 @@ Multimodal Language Models For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 .. note:: - For :code:`LLaVA-NeXT-Video`, :code:`LLaVA-Onevision` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. + For :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. This can be installed by running the following command: .. code-block:: bash From 7193774b1ff8603ad5bf4598e5efba0d9a39b436 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 25 Sep 2024 17:46:22 -0400 Subject: [PATCH 43/50] [Misc] Support quantization of MllamaForCausalLM (#8822) --- vllm/model_executor/models/mllama.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index aa868a3b8da28..45d6ad3c0efa5 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -624,6 +624,7 @@ def __init__( self, config: Optional[config_mllama.MllamaTextConfig] = None, layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -648,12 +649,14 @@ def __init__( self.num_heads, self.num_key_value_heads, bias=False, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.num_heads * self.head_dim, self.hidden_size, bias=False, input_is_parallel=True, + quant_config=quant_config, ) # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, # use huggingface's instead @@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" - def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \ + def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int, + quant_config: Optional[QuantizationConfig]) \ -> None: super().__init__() self.layer_idx = layer_idx self.cross_attn = MllamaTextCrossAttention( config=config, layer_idx=layer_idx, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, @@ -725,6 +730,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \ hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + quant_config=quant_config, ) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -780,7 +786,8 @@ def __init__(self, config: config_mllama.MllamaTextConfig, for layer_idx in range(config.num_hidden_layers): if layer_idx in self.cross_attention_layers: layers.append( - MllamaCrossAttentionDecoderLayer(config, layer_idx)) + MllamaCrossAttentionDecoderLayer( + config, layer_idx, quant_config=quant_config)) else: # TODO: force LlamaDecoderLayer to config.attention_bias=False layers.append( From 4bb98f2190aaf408cb063df5184829fb54ee5f81 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 26 Sep 2024 07:45:30 -0700 Subject: [PATCH 44/50] [Misc] Update config loading for Qwen2-VL and remove Granite (#8837) --- 
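Both the Mllama override introduced earlier in this series and the Qwen2-VL config added in this patch follow the same pattern: subclass the Hugging Face config to force the fields vLLM needs, then register it with `AutoConfig` using `exist_ok=True` so the registration can shadow an entry transformers already ships. A minimal sketch of that pattern with a hypothetical `my_model` type (not an actual vLLM config):

```python
from transformers import AutoConfig, PretrainedConfig


class MyModelConfig(PretrainedConfig):
    model_type = "my_model"  # hypothetical model type

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Force the flag the inference engine relies on, regardless of
        # what the upstream checkpoint config says.
        self.is_encoder_decoder = True


# exist_ok=True lets the registration succeed even if transformers
# already has a config class mapped to this model_type.
AutoConfig.register("my_model", MyModelConfig, exist_ok=True)

cfg = MyModelConfig(hidden_size=1024)
print(cfg.model_type, cfg.is_encoder_decoder)  # my_model True
```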
docs/source/models/supported_models.rst | 11 +- vllm/model_executor/models/granite.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 5 +- vllm/transformers_utils/config.py | 12 +- vllm/transformers_utils/configs/__init__.py | 8 +- vllm/transformers_utils/configs/granite.py | 199 -------------------- vllm/transformers_utils/configs/qwen2vl.py | 131 +++++++++++++ 7 files changed, 144 insertions(+), 224 deletions(-) delete mode 100644 vllm/transformers_utils/configs/granite.py create mode 100644 vllm/transformers_utils/configs/qwen2vl.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c807617a2c10d..c41903f84910d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -280,7 +280,7 @@ Multimodal Language Models - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - * - :code:`Qwen2VLForConditionalGeneration` - - Qwen2-VL (see note) + - Qwen2-VL - Image\ :sup:`+` / Video\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - @@ -297,15 +297,6 @@ Multimodal Language Models For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 -.. note:: - For :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. - This can be installed by running the following command: - - .. code-block:: bash - - pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830 - ----- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. 
Otherwise, please refer to :ref:`Adding a New Model ` and :ref:`Enabling Multimodal Inputs ` diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 5f365bbc30670..d4853fd790098 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -25,6 +25,7 @@ import torch from torch import nn +from transformers import GraniteConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -48,7 +49,6 @@ default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.granite import GraniteConfig from vllm.utils import is_hip from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 889ebc6c2e1ff..f895e693b7107 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -31,12 +31,9 @@ import torch.nn.functional as F from einops import rearrange, repeat from PIL import Image -from transformers import Qwen2VLConfig from transformers.image_utils import (get_image_size, infer_channel_dimension_format, to_numpy_array) -from transformers.models.qwen2_vl.configuration_qwen2_vl import ( - Qwen2VLVisionConfig) from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( make_batched_images, make_batched_videos, smart_resize) @@ -66,6 +63,8 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, + Qwen2VLVisionConfig) from vllm.transformers_utils.processor import get_processor from vllm.utils import is_cpu diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3871c0cb8b819..0f20e8d0c8213 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -20,10 +20,10 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, - GraniteConfig, InternVLChatConfig, - JAISConfig, MedusaConfig, - MllamaConfig, MLPSpeculatorConfig, - MPTConfig, NemotronConfig, + InternVLChatConfig, JAISConfig, + MedusaConfig, MllamaConfig, + MLPSpeculatorConfig, MPTConfig, + NemotronConfig, Qwen2VLConfig, RWConfig, SolarConfig, UltravoxConfig) # yapf: enable @@ -57,9 +57,7 @@ "nemotron": NemotronConfig, "solar": SolarConfig, "ultravox": UltravoxConfig, - # Granite can be removed from here once we have upgraded to - # transformers 4.45+ - "granite": GraniteConfig, + "qwen2_vl": Qwen2VLConfig, **_CONFIG_REGISTRY_OVERRIDE_HF } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index d5b13adb58a0b..462cd964325d2 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -6,7 +6,6 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig -from vllm.transformers_utils.configs.granite import GraniteConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -14,6 +13,8 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, + Qwen2VLVisionConfig) from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -32,7 +33,6 @@ "NemotronConfig", "SolarConfig", "UltravoxConfig", - # Granite can be removed from here once we have upgraded to - # transformers 4.45+ - "GraniteConfig", + "Qwen2VLConfig", + "Qwen2VLVisionConfig", ] diff --git a/vllm/transformers_utils/configs/granite.py b/vllm/transformers_utils/configs/granite.py deleted file mode 100644 index c12838be5d385..0000000000000 --- a/vllm/transformers_utils/configs/granite.py +++ /dev/null @@ -1,199 +0,0 @@ -# coding=utf-8 -# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Granite model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import rope_config_validation -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class GraniteConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of - a [`GraniteModel`]. It is used to instantiate an Granite - model according to the specified arguments, defining the model architecture. - Instantiating a configuration with the defaults will yield a similar - configuration to that of the Granite-3B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to - control the model outputs. Read the documentation from [`PretrainedConfig`] - for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Granite model. Defines the number of - different tokens that can be represented by the `inputs_ids` - passed when calling [`GraniteModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. 
- num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the - Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to - implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi - Head Attention (MHA), if `num_key_value_heads=1` the model will use - Multi Query Attention (MQA) otherwise GQA is used. When converting - a multi-head checkpoint to a GQA checkpoint, each group key and - value head should be constructed by meanpooling all the original - heads within that group. For more details checkout - [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not - specified, will default to `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the - decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values - attentions (not used by all models). Only relevant if - `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE - embeddings. Currently supports two scaling strategies: linear and - dynamic. Their scaling factor must be a float greater than 1. The - expected format is - `{"type": strategy name, "factor": scaling factor}`. - When using this flag, don't update `max_position_embeddings` to - the expected new maximum. See the following thread for more - information on how these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. - This is an experimental feature, subject to breaking API changes - in future versions. - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output - projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers - in the MLP layers. 
- embedding_multiplier (`float`, *optional*, defaults to 1.0): - embedding multiplier - logits_scaling (`float`, *optional*, defaults to 1.0): - divisor for output logits - residual_multiplier (`float`, *optional*, defaults to 1.0): - residual multiplier - attention_multiplier (`float`, *optional*, defaults to 1.0): - attention multiplier - - ```python - >>> from transformers import GraniteModel, GraniteConfig - - >>> # Initializing a Granite granite-3b style configuration - >>> configuration = GraniteConfig() - - >>> # Initializing a model from the granite-7b style configuration - >>> model = GraniteModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "granite" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - mlp_bias=False, - embedding_multiplier=1.0, - logits_scaling=1.0, - residual_multiplier=1.0, - attention_multiplier=1.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.mlp_bias = mlp_bias - - self.embedding_multiplier = embedding_multiplier - self.logits_scaling = logits_scaling - self.residual_multiplier = residual_multiplier - self.attention_multiplier = attention_multiplier - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - rope_config_validation(self) diff --git a/vllm/transformers_utils/configs/qwen2vl.py b/vllm/transformers_utils/configs/qwen2vl.py new file mode 100644 index 0000000000000..92dd962790bc8 --- /dev/null +++ b/vllm/transformers_utils/configs/qwen2vl.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Qwen2VL model configuration""" + +import os +from typing import Union + +from transformers import PretrainedConfig + + +class Qwen2VLVisionConfig(PretrainedConfig): + model_type = "qwen2_vl" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + + if config_dict.get("model_type") == "qwen2_vl": + config_dict = config_dict["vision_config"] + + return cls.from_dict(config_dict, **kwargs) + + +class Qwen2VLConfig(PretrainedConfig): + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # NOTE: the following section from original transformers config + # for Qwen2-VL is commented out to address rope config loading issue + # + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # if self.rope_scaling["type"] == "mrope": + # self.rope_scaling["type"] = "default" + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # rope_config_validation(self) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) From 1adaa9a157155b0ed090811a0751604773e0d54c Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Fri, 27 Sep 2024 15:18:55 +0000 Subject: [PATCH 45/50] Add setuptools-scm requirement to requirements-rocm since we don't use requirements-build Changing back PromptType to PromptInputs following refactoring revert --- requirements-rocm.txt | 3 ++- 
vllm/entrypoints/fast_sync_llm.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 121123611d2da..9e3c4a86cd81d 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -8,4 +8,5 @@ botocore ray >= 2.10.0 peft pytest-asyncio -tensorizer>=2.9.0 \ No newline at end of file +tensorizer>=2.9.0 +setuptools-scm>=8 \ No newline at end of file diff --git a/vllm/entrypoints/fast_sync_llm.py b/vllm/entrypoints/fast_sync_llm.py index eb36d124a89fa..c948fc97feeb9 100644 --- a/vllm/entrypoints/fast_sync_llm.py +++ b/vllm/entrypoints/fast_sync_llm.py @@ -8,7 +8,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor from vllm.executor.ray_gpu_executor import RayGPUExecutor -from vllm.inputs import PromptType, TokensPrompt +from vllm.inputs import PromptInputs, TokensPrompt from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -40,7 +40,7 @@ def __init__( def _add_request( self, - inputs: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], request_id: str, ) -> None: From 2d7ab9e5bb74dbeb02a9a047040e945ee24546cf Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:42:07 -0500 Subject: [PATCH 46/50] fix dbrx weight loader (#212) Co-authored-by: Charlie Fu --- vllm/model_executor/models/dbrx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 4fbf0c3270fab..a2ce325fd7999 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -392,7 +392,7 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [( - "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight", + "w13_" if weight_name in ["w1", "v1"] else "w2_", f"mlp.{weight_name}.", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) From 47d6392386e7f7b7b0c2ee695837dc28790c8295 Mon Sep 17 00:00:00 2001 From: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Date: Thu, 3 Oct 2024 10:42:14 -0500 Subject: [PATCH 47/50] Make rpdtracer import only when required (#216) * make rpdtracer import optional * fix rpd_mark * convert rpd_mark to try/except * move rpd_trace import down * move import --- benchmarks/profiling/benchmark_latency.py | 2 +- benchmarks/profiling/benchmark_throughput.py | 2 +- vllm/utils.py | 30 ++++++++++++++------ vllm/worker/worker.py | 3 +- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/benchmarks/profiling/benchmark_latency.py b/benchmarks/profiling/benchmark_latency.py index 96ebbf760f25b..8a5834ff7b044 100644 --- a/benchmarks/profiling/benchmark_latency.py +++ b/benchmarks/profiling/benchmark_latency.py @@ -9,7 +9,6 @@ import numpy as np import torch -from rpdTracerControl import rpdTracerControl as rpd from tqdm import tqdm from vllm import LLM, SamplingParams @@ -24,6 +23,7 @@ def main(args: argparse.Namespace): @contextmanager def rpd_profiler_context(): + from rpdTracerControl import rpdTracerControl as rpd llm.start_profile() yield llm.stop_profile() diff --git a/benchmarks/profiling/benchmark_throughput.py b/benchmarks/profiling/benchmark_throughput.py index 35d8fa1224661..46b587c16a8ca 100644 --- a/benchmarks/profiling/benchmark_throughput.py +++ 
b/benchmarks/profiling/benchmark_throughput.py @@ -9,7 +9,6 @@ import torch import uvloop -from rpdTracerControl import rpdTracerControl as rpd from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) @@ -98,6 +97,7 @@ def run_vllm( @contextmanager def rpd_profiler_context(): + from rpdTracerControl import rpdTracerControl as rpd llm.start_profile() yield llm.stop_profile() diff --git a/vllm/utils.py b/vllm/utils.py index c74fd098ec6ee..3c28cb275f276 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,7 +31,6 @@ import torch.types import yaml from packaging.version import Version -from rpdTracerControl import rpdTracerControl from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs @@ -208,6 +207,7 @@ def setup_environment_variables(filename): def initialize_rpd_tracer(self, filename, nvtx): try: + from rpdTracerControl import rpdTracerControl rpd_trace.setup_environment_variables(filename) rpdTracerControl.setFilename(name=filename, append=True) return rpdTracerControl(nvtx=nvtx) @@ -233,21 +233,35 @@ def create_file(filename): print(f"An error occurred while creating the filename: {e}") +@lru_cache(maxsize=None) +def is_hipScopedMarker_available(): + try: + from hipScopedMarker import hipScopedMarker + except ImportError: + hipScopedMarker = None + return hipScopedMarker is not None + + class rpd_mark(): def __init__(self, name=None): self.name = name def __call__(self, func): - from hipScopedMarker import hipScopedMarker - @wraps(func) - def inner(*args, **kwds): - marker_name = self.name if self.name else f"{func.__name__}" - with hipScopedMarker(f"{marker_name}"): - return func(*args, **kwds) + if is_hipScopedMarker_available(): + from hipScopedMarker import hipScopedMarker - return inner + @wraps(func) + def inner(*args, **kwds): + marker_name = self.name if self.name else f"{func.__name__}" + with hipScopedMarker(f"{marker_name}"): + return func(*args, **kwds) + + return inner + + else: + return func class Device(enum.Enum): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9ab00d60fae60..e060fb5e3436c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -24,7 +24,6 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.utils import rpd_trace from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -144,6 +143,8 @@ def __init__( logger.info("Profiling enabled. 
Traces will be saved to: %s", rpd_profiler_trace_dir) + from vllm.utils import rpd_trace + if self.rank == 0: rpd_trace.create_file(filename=str(rpd_profiler_trace_dir)) From 4cb422f45f1e95cff6e296f7c3c3d964479e8652 Mon Sep 17 00:00:00 2001 From: Adrian Abeyta Date: Thu, 3 Oct 2024 10:55:22 -0500 Subject: [PATCH 48/50] Improve profiling setup and documentation, sync benchmarks with main (#218) * Automatically set rpd env var with profile flag * Add readme * Fix lint errors --------- Co-authored-by: AdrianAbeyta --- benchmarks/profiling/README.md | 59 ++++++++++++++++++++ benchmarks/profiling/benchmark_latency.py | 36 +++++++----- benchmarks/profiling/benchmark_throughput.py | 24 +++++--- vllm/utils.py | 3 - 4 files changed, 97 insertions(+), 25 deletions(-) create mode 100644 benchmarks/profiling/README.md diff --git a/benchmarks/profiling/README.md b/benchmarks/profiling/README.md new file mode 100644 index 0000000000000..8e029d8b9c1bf --- /dev/null +++ b/benchmarks/profiling/README.md @@ -0,0 +1,59 @@ +# VLLM Benchmark Profiling + +This profiling directory provides a method to profile VLLM throughput and latency benchmarks using ROCm profiling utilities. + +## 1. Dependencies + +Before using the profiling feature, you need to install the required dependencies: + +### Install ROCm Profile Data + +```bash +git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git +cd rocmProfileData && make && sudo make install +``` + +### Install hipMarker + +```bash +cd rocmProfileData/hipMarker && python3 setup.py install +``` + +## 2. Profiling Benchmarks + +Profiling can be used to monitor the performance of the VLLM benchmarks with ROCm. The key flags used for profiling are: + +- `--profile-rpd`: Profiles the generation process of a single batch. +- `--profile-dir PROFILE_DIR`: Specifies the path to save the profiler output, which can later be visualized using tools like [ui.perfetto.dev](https://ui.perfetto.dev/) or [chrome.tracing](chrome://tracing/). + +### Profiling Using Default Directory + +By default, profiling results are saved in either `vllm_benchmark_latency_result` or `vllm_benchmark_throughput_result`. To run a benchmark and profile it using the default directory, execute: + +```bash +python3 benchmark_throughput.py --input-len {len} --output-len {len} --model {model} --profile-rpd +``` + +### Profiling With a Custom Directory + +You can specify a custom directory for saving profiler outputs by using the `--profile-dir` flag: + +```bash +python3 benchmark_throughput.py --input-len {len} --output-len {len} --model {model} --profile-rpd --profile-dir {/path/to/custom/dir} +``` + +After profiling is complete, an `.rpd` file containing the trace data will be saved to the specified directory. + +## 3. Convert Trace Data to JSON Format + +To view the trace data, it needs to be converted into a format that is compatible with tools like Chrome tracing or Perfetto. + +You can use the `rpd2tracing.py` script in rocmProfileData to convert the `.rpd` file into a JSON file: + +```bash +python3 rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json +``` + +Once the trace is converted, open the `.json` file in [Chrome](chrome://tracing/) or [Perfetto](https://ui.perfetto.dev/) for visualization. 
+ + diff --git a/benchmarks/profiling/benchmark_latency.py b/benchmarks/profiling/benchmark_latency.py index 8a5834ff7b044..07ca021e66ddc 100644 --- a/benchmarks/profiling/benchmark_latency.py +++ b/benchmarks/profiling/benchmark_latency.py @@ -13,7 +13,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -30,8 +30,7 @@ def rpd_profiler_context(): rpd.top_totals() @contextmanager - def torch_profiler_context(profile_dir: Optional[str] = None, - trace_file_name=None): + def torch_profiler_context(profile_dir: Optional[str] = None): p = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, @@ -48,15 +47,27 @@ def torch_profiler_context(profile_dir: Optional[str] = None, print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) - def get_profiling_context(profile_dir: Optional[str] = None, - trace_file_name=None): + def get_profiling_context(profile_dir: Optional[str] = None): if args.profile_torch: - return torch_profiler_context(profile_dir, trace_file_name) + return torch_profiler_context(profile_dir) elif args.profile_rpd: return rpd_profiler_context() else: return nullcontext() + if args.profile_torch or args.profile_rpd: + profile_dir = Path(args.profile_dir + or "./vllm_benchmark_latency_result") + profile_dir.mkdir(parents=True, exist_ok=True) + name = os.path.basename(os.path.normpath(args.model)) + model_trace_name = ( + f"{name}_in_{args.input_len}_out_{args.output_len}_" + f"batch_{args.batch_size}_tp_{args.tensor_parallel_size}") + print(f"Profiling (results will be saved to '{profile_dir}')...") + if args.profile_rpd: + profile_dir /= f"{model_trace_name}.rpd" + os.environ["VLLM_RPD_PROFILER_DIR"] = str(profile_dir) + # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
llm = LLM( @@ -100,19 +111,19 @@ def get_profiling_context(profile_dir: Optional[str] = None, dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompts: List[PromptType] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: - with get_profiling_context(): - llm.generate(dummy_prompts, + with get_profiling_context(profile_dir): + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) else: start_time = time.perf_counter() - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() @@ -124,11 +135,6 @@ def run_to_completion(profile_dir: Optional[str] = None): run_to_completion(profile_dir=None) if args.profile_torch or args.profile_rpd: - profile_dir = args.profile_dir - if not profile_dir: - profile_dir = Path(".") / "vllm_benchmark_latency_result" - os.makedirs(profile_dir, exist_ok=True) - print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return diff --git a/benchmarks/profiling/benchmark_throughput.py b/benchmarks/profiling/benchmark_throughput.py index 46b587c16a8ca..81d126f456028 100644 --- a/benchmarks/profiling/benchmark_throughput.py +++ b/benchmarks/profiling/benchmark_throughput.py @@ -5,6 +5,7 @@ import random import time from contextlib import contextmanager, nullcontext +from pathlib import Path from typing import List, Optional, Tuple import torch @@ -121,15 +122,27 @@ def torch_profiler_context(profile_dir: Optional[str] = None): print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) - def get_profiling_context(profile_dir: Optional[str] = None, - trace_file_name=None): + def get_profiling_context(profile_dir: Optional[str] = None): if args.profile_torch: - return torch_profiler_context(profile_dir, trace_file_name) + return torch_profiler_context(profile_dir) elif args.profile_rpd: return rpd_profiler_context() else: return nullcontext() + if args.profile_torch or args.profile_rpd: + profile_dir = Path(args.profile_dir + or "./vllm_benchmark_throughput_result") + profile_dir.mkdir(parents=True, exist_ok=True) + name = os.path.basename(os.path.normpath(args.model)) + model_trace_name = ( + f"{name}_in_{args.input_len}_out_{args.output_len}_" + f"tp_{args.tensor_parallel_size}") + print(f"Profiling (results will be saved to '{profile_dir}')...") + if args.profile_rpd: + profile_dir /= f"{model_trace_name}.rpd" + os.environ["VLLM_RPD_PROFILER_DIR"] = str(profile_dir) + llm = LLM( model=model, tokenizer=tokenizer, @@ -171,10 +184,7 @@ def get_profiling_context(profile_dir: Optional[str] = None, )) if args.profile_torch or args.profile_rpd: - profile_dir = args.profile_dir - name = os.path.basename(os.path.normpath(args.model)) - model_trace_name = f"{name}_in_{args.input_len}_out_{args.output_len}" - with get_profiling_context(profile_dir, model_trace_name): + with get_profiling_context(profile_dir): llm.generate(prompts, sampling_params, use_tqdm=True) return else: diff --git a/vllm/utils.py b/vllm/utils.py index 3c28cb275f276..144dfc3eeddea 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -163,11 +163,8 @@ def __init__(self, skip=False): self.skip = skip if not self.skip: - if 'RANK' in os.environ or int(os.getenv('WORLD_SIZE', 1)) > 1: - filename = f"{filename}_pid{os.getpid()}" self.name = name self.args = args if 
args else "" - print(f"filename type {type(filename)}") self.rpd = self.initialize_rpd_tracer(filename, nvtx) def _recreate_cm(self): From 4075b35b5bbe27c729813368aa0a849948e0f33f Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Thu, 3 Oct 2024 18:20:22 -0400 Subject: [PATCH 49/50] Installing the requirements before invoking setup.py since it now imports setuptools_scm (#221) --- Dockerfile.rocm | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 1aa2754a6a3fc..e2f21b2b6105c 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -170,6 +170,7 @@ if ls /install/*.whl; then \ fi # Build vLLM RUN cd vllm \ + && python3 -m pip install -r requirements-rocm.txt \ && python3 setup.py clean --all \ && if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \ && python3 setup.py bdist_wheel --dist-dir=dist From 2550f14a77c84b93045f4603fdcf3bc310164b15 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Date: Fri, 4 Oct 2024 16:03:09 -0700 Subject: [PATCH 50/50] llama3.2 + cross attn test (#220) * llama3.2 + cross attn test * lint issues fix * mypy errors * making yapf happy * cut off WA for tunned gemms * try and catch for non continuous tensor --------- Co-authored-by: Aleksandr Malyshev --- tests/kernels/test_encoder_decoder_attn.py | 4 +- tests/kernels/utils.py | 11 +- vllm/attention/backends/rocm_flash_attn.py | 334 ++++++++++++++++----- vllm/model_executor/layers/tuned_gemm.py | 7 +- vllm/worker/enc_dec_model_runner.py | 18 +- 5 files changed, 280 insertions(+), 94 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index b550a7fdd84f0..f9b15bfb02605 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -21,7 +21,8 @@ from vllm.utils import is_hip # List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] if not is_hip() \ + else [_Backend.ROCM_FLASH] HEAD_SIZES = [64, 256] @@ -807,7 +808,6 @@ def test_encoder_only( assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 08004efe9e2f8..d1de0b20be2f7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -12,8 +12,8 @@ from torch._prims_common import TensorLikeType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, - make_tensor_with_pad) +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_ROCM_FLASH_ATTN_VAL, + STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -524,8 +524,13 @@ def make_backend(backend_name: str) -> AttentionBackend: if backend_name == STR_XFORMERS_ATTN_VAL: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. 
from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() + + if backend_name == STR_ROCM_FLASH_ATTN_VAL: + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend + raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b793ebf46d173..417dbc6d1483c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -86,6 +86,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): seq_lens: Optional[List[int]] # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| @@ -96,32 +107,38 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # |-- query_len ---| # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int + max_query_len: Optional[int] = None # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] = None # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool + seq_start_loc: Optional[torch.Tensor] = None # (batch_size,) A tensor of context lengths (tokens that are computed # so far). - context_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] = None _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... 
+ + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + @property def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: if self.num_prefills == 0: @@ -132,10 +149,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None assert self.block_tables is not None - assert self.seq_start_loc is not None self._cached_prefill_metadata = ROCmFlashAttentionMetadata( num_prefills=self.num_prefills, @@ -147,12 +161,20 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -180,7 +202,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", @@ -274,6 +301,97 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor, return attn_biases +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + ''' + + partial_prefix_sum = 0 + if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + + partial_prefix_sum = 0 + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + key_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -391,64 +509,104 @@ def forward( ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and + cross-attention layers. + * For self-attention: query, key and value must be non-None. 
+ * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "ROCmFlashAttentionImpl") - - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None - if kv_cache is not None: + if attn_type != AttentionType.ENCODER and kv_cache is not None: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - k_scale, - v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) + if key is not None and value is not None: + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. 
+ PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + if attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] + # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + if key is not None and value is not None: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] if prefill_meta := attn_metadata.prefill_metadata: + output = torch.empty_like(query) + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, attn_type) + # Prompt run. - assert prefill_meta.seq_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -459,18 +617,18 @@ def forward( attn_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, + seq_lens, make_attn_mask=False) # type: ignore out, _ = self.attn_func( query, key, value, None, - prefill_meta.seq_start_loc, - prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - prefill_meta.max_prefill_seq_len, - True, + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, self.scale, attn_masks[0][None] if attn_masks is not None else None, @@ -494,11 +652,12 @@ def forward( query, key, value, - prefill_meta.seq_lens, - num_tokens, + query_seq_start_loc, + num_prefill_tokens, self.num_heads, self.head_size, self.scale, + causal_mask, attn_masks, ) else: @@ -506,10 +665,10 @@ def forward( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -545,6 +704,7 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
# Whether to use rocm custom paged attention or not + output = torch.empty_like(decode_query) num_seqs, num_heads, head_size = decode_query.shape block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads @@ -552,7 +712,10 @@ def forward( decode_query.dtype, head_size, block_size, gqa_ratio, decode_meta.max_decode_seq_len) if use_custom: - max_seq_len = decode_meta.max_decode_seq_len + max_seq_len = (decode_meta.max_decode_seq_len + if attn_type != AttentionType.ENCODER_DECODER + else decode_meta.max_encoder_seq_len) + assert max_seq_len is not None max_num_partitions = ( (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM) @@ -573,7 +736,7 @@ def forward( else: out = output ops.paged_attention_rocm( - out, + output[num_prefill_tokens:], exp_sums, max_logits, tmp_output, @@ -582,8 +745,12 @@ def forward( value_cache, self.num_kv_heads, self.scale, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, + decode_meta.block_tables + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -596,9 +763,15 @@ def forward( decode_query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, + decode_meta.block_tables + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -608,7 +781,7 @@ def forward( ) # Reshape the output tensor. - return output.view(num_tokens, hidden_size) + return output.view(-1, self.num_heads * self.head_size) def _sdpa_attention( @@ -620,6 +793,7 @@ def _sdpa_attention( num_heads: int, head_size: int, scale: float, + is_causal: bool, attn_masks: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: start = 0 @@ -637,7 +811,7 @@ def _sdpa_attention( key[:, start:end, :], value[:, start:end, :], dropout_p=0.0, - is_causal=attn_masks is None, + is_causal=is_causal, attn_mask=attn_masks[i] if attn_masks else None, scale=scale).movedim(query.dim() - 2, 0) output[start:end, :, :] = sub_out diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 7ea1d8d93ea2b..f765b8c39fa6c 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -84,8 +84,11 @@ def mm(self, inp, weights, bias=None): # uses this for linear units. 
However, sampler # will use torch.matmul with 2 dimensions only if inp.dim() == 3: - inp_view = inp.view(-1, inp.size(-1)) - batched = True + try: + inp_view = inp.view(-1, inp.size(-1)) + batched = True + except RuntimeError: + return F.linear(inp, weights, bias) else: inp_view = inp batched = False diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index bd716ac3e7ec3..4606866bdba52 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -23,7 +23,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) -from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND, is_hip, + make_tensor_with_pad) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata, @@ -120,7 +121,7 @@ def __init__( def _maybe_force_supported_attention_backend(self): ''' - Force vLLM to use the XFormers attention backend, + Force vLLM to use the XFormers or ROCM attention backend, which is currently the only supported option. ''' @@ -138,18 +139,21 @@ def raise_backend_err(): # The user has not already specified an attention backend # override logger.info("EncoderDecoderModelRunner requires " - "XFormers backend; overriding backend " - "auto-selection and forcing XFormers.") - global_force_attn_backend(_Backend.XFORMERS) + "XFormers or ROCM backend; overriding backend " + "auto-selection and forcing XFormers or ROCM.") + global_force_attn_backend( + _Backend.ROCM_FLASH if is_hip() else _Backend.XFORMERS) elif is_forced_by_global: # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. - if maybe_global_forced_backend != _Backend.XFORMERS: + if maybe_global_forced_backend != _Backend.XFORMERS and \ + maybe_global_forced_backend != _Backend.ROCM_FLASH: raise_backend_err() elif is_forced_by_env_var: # Backend override enforced by vLLM backend # environment variable - if maybe_env_var_forced_backend != _Backend.XFORMERS: + if maybe_env_var_forced_backend != _Backend.XFORMERS and \ + maybe_global_forced_backend != _Backend.ROCM_FLASH: raise_backend_err() def _list_to_int32_tensor(
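
The helper added in the rocm_flash_attn.py hunk above, _get_seq_len_block_table_args(), builds the cumulative start-location lists ("cu_seqlens") that the varlen attention kernels expect by running a walrus-operator prefix sum over per-sequence lengths; for ENCODER_DECODER it pairs decoder-side query start locations with encoder-side key start locations. A minimal standalone sketch of that construction follows; the helper name and the example lengths are hypothetical and only illustrate the pattern, they are not part of the patch:

from typing import List


def make_seq_start_loc(seq_lens: List[int]) -> List[int]:
    """Return [0, l0, l0 + l1, ...] for the given per-sequence lengths."""
    partial_prefix_sum = 0
    return [0] + [
        # Walrus assignment keeps a running total, mirroring the
        # comprehension used in _get_seq_len_block_table_args().
        partial_prefix_sum := partial_prefix_sum + length
        for length in seq_lens
    ]


if __name__ == "__main__":
    # Lengths [4, 6] yield [0, 4, 10], matching the query_start_loc /
    # seq_start_loc examples in ROCmFlashAttentionMetadata.
    assert make_seq_start_loc([4, 6]) == [0, 4, 10]

    # For cross-attention, query start locations come from the decoder
    # sequence lengths and key start locations from the encoder ones.
    decoder_seq_lens = [4, 6]   # hypothetical decoder lengths
    encoder_seq_lens = [7, 3]   # hypothetical encoder lengths
    print(make_seq_start_loc(decoder_seq_lens))  # [0, 4, 10]
    print(make_seq_start_loc(encoder_seq_lens))  # [0, 7, 10]

These lists correspond to what the patch passes as cu_seqlens_q / cu_seqlens_k, after wrapping them in torch tensors on the device and dtype of the existing sequence-length tensors.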