diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index fbf41eb10a392..d3f5fc5cd4cee 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -1,15 +1,13 @@ # vLLM benchmark suite - ## Introduction This directory contains two sets of benchmark for vllm. + - Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance - Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm. - -See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results. - +See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results. ## Performance benchmark quick overview @@ -19,17 +17,14 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan **For benchmarking developers**: please try your best to constraint the duration of benchmarking to about 1 hr so that it won't take forever to run. - ## Nightly benchmark quick overview -**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. +**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. **Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy. **Benchmarking Duration**: about 3.5hrs. - - ## Trigger the benchmark Performance benchmark will be triggered when: @@ -39,16 +34,11 @@ Performance benchmark will be triggered when: Nightly benchmark will be triggered when: - Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. - - - ## Performance benchmark details - See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. - -#### Latency test +### Latency test Here is an example of one test inside `latency-tests.json`: @@ -68,23 +58,25 @@ Here is an example of one test inside `latency-tests.json`: ``` In this example: -- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. -- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` + +- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. +- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. 
Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly. WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file. +### Throughput test -#### Throughput test The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`. The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot. -#### Serving test +### Serving test + We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example: -``` +```json [ { "test_name": "serving_llama8B_tp1_sharegpt", @@ -109,6 +101,7 @@ We test the throughput by using `benchmark_serving.py` with request rate = inf t ``` Inside this example: + - The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`. - The `server-parameters` includes the command line arguments for vLLM server. - The `client-parameters` includes the command line arguments for `benchmark_serving.py`. @@ -118,36 +111,33 @@ The number of this test is less stable compared to the delay and latency benchma WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. -#### Visualizing the results +### Visualizing the results + The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. If you do not see the table, please wait till the benchmark finish running. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. - - ## Nightly test details See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. +### Workflow -#### Workflow - -- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. +- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. - Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container. - The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark. 
- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite. -#### Nightly tests +### Nightly tests In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is highly similar to performance benchmark. -#### Docker containers +### Docker containers The docker containers for benchmarking are specified in `nightly-pipeline.yaml`. WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`. WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git). - diff --git a/.buildkite/nightly-benchmarks/nightly-annotation.md b/.buildkite/nightly-benchmarks/nightly-annotation.md index 1e33793842bf8..e43ea765f1556 100644 --- a/.buildkite/nightly-benchmarks/nightly-annotation.md +++ b/.buildkite/nightly-benchmarks/nightly-annotation.md @@ -9,20 +9,19 @@ This file contains the downloading link for benchmarking results. Please download the visualization scripts in the post - ## Results reproduction - Find the docker we use in `benchmarking pipeline` - Deploy the docker, and inside the docker: - - Download `nightly-benchmarks.zip`. - - In the same folder, run the following code -``` -export HF_TOKEN= -apt update -apt install -y git -unzip nightly-benchmarks.zip -VLLM_SOURCE_CODE_LOC=./ bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh -``` + - Download `nightly-benchmarks.zip`. + - In the same folder, run the following code: -And the results will be inside `./benchmarks/results`. + ```console + export HF_TOKEN= + apt update + apt install -y git + unzip nightly-benchmarks.zip + VLLM_SOURCE_CODE_LOC=./ bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh + ``` +And the results will be inside `./benchmarks/results`. diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 7dec7a0fe0b4e..5f003f42f07c0 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -2,6 +2,7 @@ # Nightly benchmark This benchmark aims to: + - Provide performance clarity: Provide clarity on which one (vllm, tensorrt-llm, lmdeploy and SGLang) leads in performance in what workload. - Be reproducible: one can run the exact same set of benchmarking commands inside the exact same docker by following reproducing instructions. @@ -9,7 +10,6 @@ Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html) Latest reproduction guilde: [github issue link](https://github.com/vllm-project/vllm/issues/8176) - ## Setup - Docker images: @@ -33,7 +33,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - Queries are randomly sampled, and arrival patterns are determined via Poisson process, but all with fixed random seed. - Evaluation metrics: Throughput (higher the better), TTFT (time to the first token, lower the better), ITL (inter-token latency, lower the better). -# Known issues +## Known issues - TRT-LLM crashes with Llama 3.1 8B [issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105). 
-- TGI does not support `ignore-eos` flag. \ No newline at end of file +- TGI does not support `ignore-eos` flag. diff --git a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md index da32d1f073cea..cacaef986c658 100644 --- a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md +++ b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md @@ -7,10 +7,8 @@ - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. - Evaluation metrics: end-to-end latency (mean, median, p99). - {latency_tests_markdown_table} - ## Throughput tests - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). @@ -19,10 +17,8 @@ - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. - Evaluation metrics: throughput. - {throughput_tests_markdown_table} - ## Serving tests - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). @@ -33,13 +29,11 @@ - We also added a speculative decoding test for llama-3 70B, under QPS 2 - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). - {serving_tests_markdown_table} - ## json version of the benchmarking tables -This section contains the data of the markdown tables above in JSON format. +This section contains the data of the markdown tables above in JSON format. You can load the benchmarking tables into pandas dataframes as follows: ```python @@ -54,9 +48,9 @@ serving_results = pd.DataFrame.from_dict(benchmarking_results["serving"]) ``` The json string for all benchmarking tables: + ```json {benchmarking_results_in_json_string} ``` You can also check the raw experiment data in the Artifact tab of the Buildkite page. - diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh index 1ad77cf50f612..55c374fcc33de 100644 --- a/.buildkite/run-neuron-test.sh +++ b/.buildkite/run-neuron-test.sh @@ -29,9 +29,6 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then docker image prune -f # Remove unused volumes / force the system prune for old images as well. docker volume prune -f && docker system prune -f - # Remove huggingface model artifacts and compiler cache - rm -rf "${HF_MOUNT:?}/*" - rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*" echo "$current_time" > /tmp/neuron-docker-build-timestamp fi else diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7ef40564c5bd2..948eab97ffae2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -128,7 +128,7 @@ steps: - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile - examples/offline_inference/rlhf.py - - examples/offline_inference/ray_placement.py + - examples/offline_inference/rlhf_colocate.py commands: - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py @@ -137,7 +137,7 @@ steps: # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - python3 ../examples/offline_inference/rlhf.py - - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py + - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py - label: Metrics, Tracing Test # 10min num_gpus: 2 @@ -195,6 +195,9 @@ steps: # TODO: accuracy does not match, whether setting # VLLM_USE_FLASHINFER_SAMPLER or not on H100. 
- VLLM_USE_V1=1 pytest -v -s v1/e2e + # Integration test for streaming correctness (requires special branch). + - pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api + - pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min working_dir: "/vllm-workspace/examples" diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 51a73c857ccb2..a20c5baf895c1 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,4 +2,5 @@ FILL IN THE PR DESCRIPTION HERE FIX #xxxx (*link existing issues this PR will resolve*) -**BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html ** + +**BEFORE SUBMITTING, PLEASE READ ** diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4568efcbba211..352eb2df01b98 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,36 +8,41 @@ repos: - id: yapf args: [--in-place, --verbose] additional_dependencies: [toml] # TODO: Remove when yapf is upgraded + exclude: 'vllm/third_party/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.3 hooks: - id: ruff args: [--output-format, github] + exclude: 'vllm/third_party/.*' - repo: https://github.com/codespell-project/codespell rev: v2.4.0 hooks: - id: codespell - exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*' + exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*|vllm/third_party/.*' - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: - id: isort + exclude: 'vllm/third_party/.*' - repo: https://github.com/pre-commit/mirrors-clang-format rev: v19.1.7 hooks: - id: clang-format - exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))' + exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' types_or: [c++, cuda] args: [--style=file, --verbose] - repo: https://github.com/jackdewinter/pymarkdown rev: v0.9.27 hooks: - id: pymarkdown - files: docs/.* + args: [fix] + exclude: 'vllm/third_party/.*' - repo: https://github.com/rhysd/actionlint rev: v1.7.7 hooks: - id: actionlint + exclude: 'vllm/third_party/.*' - repo: local hooks: - id: mypy-local @@ -47,6 +52,7 @@ repos: types: [python] additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests] stages: [pre-commit] # Don't run in CI + exclude: 'vllm/third_party/.*' - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.9 entry: tools/mypy.sh 1 "3.9" @@ -54,6 +60,7 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 entry: tools/mypy.sh 1 "3.10" @@ -61,6 +68,7 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 entry: tools/mypy.sh 1 "3.11" @@ -68,6 +76,7 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: 
mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 entry: tools/mypy.sh 1 "3.12" @@ -75,16 +84,19 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: shellcheck name: Lint shell scripts entry: tools/shellcheck.sh language: script types: [shell] + exclude: 'vllm/third_party/.*' - id: png-lint name: Lint PNG exports from excalidraw entry: tools/png-lint.sh language: script types: [png] + exclude: 'vllm/third_party/.*' - id: signoff-commit name: Sign-off Commit entry: bash @@ -97,14 +109,27 @@ repos: language: system verbose: true stages: [commit-msg] + exclude: 'vllm/third_party/.*' - id: check-spdx-header name: Check SPDX headers entry: python tools/check_spdx_header.py language: python types: [python] + exclude: 'vllm/third_party/.*' - id: suggestion name: Suggestion entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' language: system verbose: true pass_filenames: false + exclude: 'vllm/third_party/.*' + - id: check-filenames + name: Check for spaces in all filenames + entry: bash + args: + - -c + - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' + language: system + always_run: true + pass_filenames: false + exclude: 'vllm/third_party/.*' diff --git a/CMakeLists.txt b/CMakeLists.txt index c823c9ff895c3..b99061dfde4fd 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -581,7 +581,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG d4e09037abf588af1ec47d0e966b237ee376876c + GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 1a9596841cc65..5268ff135c9d0 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -125,4 +125,3 @@ Community Impact Guidelines were inspired by For answers to common questions about this code of conduct, see the [Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at [Contributor Covenant translations](https://www.contributor-covenant.org/translations). - diff --git a/Dockerfile.neuron b/Dockerfile.neuron index e9cb82889decd..27658d836d988 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -23,10 +23,12 @@ WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas RUN python3 -m pip install sentencepiece transformers==4.45.2 -U -RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install pytest +# uninstall transformers-neuronx package explicitly to avoid version conflict +RUN python3 -m pip uninstall -y transformers-neuronx + COPY . . 
ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ @@ -43,6 +45,10 @@ RUN --mount=type=bind,source=.git,target=.git \ # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils +# install transformers-neuronx package as an optional dependencies (for V0) +# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict +RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps + # overwrite entrypoint to run bash script RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index 5bbe98b0c2204..e33e73b303098 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -6,7 +6,7 @@ ARG RCCL_BRANCH="648a58d" ARG RCCL_REPO="https://github.com/ROCm/rccl" ARG TRITON_BRANCH="e5be006" ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="8d4926e" +ARG PYTORCH_BRANCH="3a585126" ARG PYTORCH_VISION_BRANCH="v0.19.1" ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" diff --git a/README.md b/README.md index cd0b1c517fdbd..f22a1f9c5c80e 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,12 @@ Easy, fast, and cheap LLM serving for everyone --- +We are excited to invite you to our Menlo Park meetup with Meta, evening of Thursday, February 27! Meta engineers will discuss the improvements on top of vLLM, and vLLM contributors will share updates from the v0.7.x series of releases. [Register Now](https://lu.ma/h7g3kuj9) + +--- + *Latest News* 🔥 + - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). - [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! @@ -33,7 +38,9 @@ Easy, fast, and cheap LLM serving for everyone - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). --- + ## About + vLLM is a fast and easy-to-use library for LLM inference and serving. Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry. 
@@ -127,6 +134,7 @@ We also have an official fundraising venue through [OpenCollective](https://open ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): + ```bibtex @inproceedings{kwon2023efficient, title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, @@ -138,11 +146,11 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs ## Contact Us -* For technical questions and feature requests, please use Github issues or discussions. -* For discussing with fellow users and coordinating contributions and development, please use Slack. -* For security disclosures, please use Github's security advisory feature. -* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. +- For technical questions and feature requests, please use Github issues or discussions. +- For discussing with fellow users and coordinating contributions and development, please use Slack. +- For security disclosures, please use Github's security advisory feature. +- For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. ## Media Kit -* If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). +- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). diff --git a/benchmarks/README.md b/benchmarks/README.md index 2aa4a285021f1..367ef93457f9f 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -3,6 +3,7 @@ ## Downloading the ShareGPT dataset You can download the dataset by running: + ```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ``` @@ -11,9 +12,18 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts will ignore a datapoint if the referred image is missing. 
+ ```bash wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json mkdir coco -p wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip unzip coco/train2017.zip -d coco/ ``` + +# Downloading the BurstGPT dataset + +You can download the BurstGPT v1.1 dataset by running: + +```bash +wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv +``` diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index e934d228f7fd4..0c892384236bc 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -38,6 +38,7 @@ from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np +import pandas as pd from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) from datasets import load_dataset @@ -131,6 +132,35 @@ def sample_sharegpt_requests( return filtered_dataset +def sample_burstgpt_requests( + dataset_path: str, + num_requests: int, + random_seed: int, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int, None]]: + df = pd.read_csv(dataset_path) + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove the failed requests (i.e., response length is 0) + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Randomly sample num_requests from the dataset + if num_requests <= len(gpt4_df): + gpt4_df = gpt4_df.sample(n=num_requests, random_state=random_seed) + else: + gpt4_df = gpt4_df.sample(n=num_requests, + random_state=random_seed, + replace=True) + # Convert the dataframe to a list of tuples + dataset = gpt4_df.values.tolist() + input_requests = [] + for i in range(num_requests): + input_len = int(dataset[i][2]) + output_len = int(dataset[i][3]) + prompt = tokenizer.decode([(i + j) % tokenizer.vocab_size + for j in range(input_len)]) + input_requests.append((prompt, input_len, output_len, None)) + return input_requests + + def sample_sonnet_requests( dataset_path: str, num_requests: int, @@ -537,6 +567,7 @@ async def benchmark( ignore_eos: bool, goodput_config_dict: Dict[str, float], max_concurrency: Optional[int], + lora_modules: Optional[List[str]], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -562,6 +593,7 @@ async def benchmark( multi_modal_content=test_mm_content, ignore_eos=ignore_eos, ) + test_output = await request_func(request_func_input=test_input) if not test_output.success: raise ValueError( @@ -570,6 +602,11 @@ async def benchmark( else: print("Initial test run completed. Starting main benchmark run...") + if lora_modules: + # For each input request, choose a LoRA module at random. 
+ lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + if profile: print("Starting profiler...") profile_input = RequestFuncInput(model=model_id, @@ -616,8 +653,13 @@ async def limited_request_func(request_func_input, pbar): tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request - request_func_input = RequestFuncInput(model=model_id, - model_name=model_name, + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, prompt=prompt, api_url=api_url, prompt_len=prompt_len, @@ -818,6 +860,14 @@ def main(args: argparse.Namespace): fixed_output_len=args.sharegpt_output_len, ) + elif args.dataset_name == "burstgpt": + input_requests = sample_burstgpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + random_seed=args.seed, + tokenizer=tokenizer, + ) + elif args.dataset_name == "sonnet": # Do not format the prompt, pass to message directly if args.backend == "openai-chat": @@ -900,6 +950,7 @@ def main(args: argparse.Namespace): ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, )) # Save config and results to json @@ -982,7 +1033,7 @@ def main(args: argparse.Namespace): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random", "hf"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", @@ -1237,5 +1288,12 @@ def main(args: argparse.Namespace): "If not specified, the model name will be the " "same as the ``--model`` argument. ") + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. 
For each request, the " + "script chooses a LoRA module at random.") + args = parser.parse_args() main(args) diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp index e8555d853b7ac..fab6ca36d422e 100644 --- a/csrc/cumem_allocator.cpp +++ b/csrc/cumem_allocator.cpp @@ -12,15 +12,21 @@ extern "C" { #include #include -#define CUDA_CHECK(condition) \ - do { \ - CUresult error = condition; \ - if (error != 0) { \ - char* error_string; \ - cuGetErrorString(error, (const char**)&error_string); \ - std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \ - << __LINE__ << std::endl; \ - } \ +char error_msg[10240]; // 10KB buffer to store error messages +CUresult no_error = CUresult(0); +CUresult error_code = no_error; // store error code + +#define CUDA_CHECK(condition) \ + do { \ + CUresult error = condition; \ + if (error != 0) { \ + error_code = error; \ + char* error_string; \ + cuGetErrorString(error, (const char**)&error_string); \ + snprintf(error_msg, sizeof(error_msg), "CUDA Error: %s at %s:%d", \ + error_string, __FILE__, __LINE__); \ + std::cerr << error_msg << std::endl; \ + } \ } while (0) // Global references to Python callables @@ -54,14 +60,22 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, // Allocate memory using cuMemCreate CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); + if (error_code != 0) { + return; + } CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); - + if (error_code != 0) { + return; + } CUmemAccessDesc accessDesc = {}; accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; accessDesc.location.id = device; accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); + if (error_code != 0) { + return; + } // std::cout << "create_and_map: device=" << device << ", size=" << size << ", // d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; } @@ -73,7 +87,13 @@ void unmap_and_release(unsigned long long device, ssize_t size, // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; ensure_context(device); CUDA_CHECK(cuMemUnmap(d_mem, size)); + if (error_code != 0) { + return; + } CUDA_CHECK(cuMemRelease(*p_memHandle)); + if (error_code != 0) { + return; + } } PyObject* create_tuple_from_c_integers(unsigned long long a, @@ -121,12 +141,16 @@ void* my_malloc(ssize_t size, int device, CUstream stream) { size_t granularity; CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); - + if (error_code != 0) { + return nullptr; + } size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; CUdeviceptr d_mem; CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); - + if (error_code != 0) { + return nullptr; + } // allocate the CUmemGenericAllocationHandle CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)malloc( @@ -208,6 +232,9 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) { // free address and the handle CUDA_CHECK(cuMemAddressFree(d_mem, size)); + if (error_code != 0) { + return; + } free(p_memHandle); } @@ -258,6 +285,12 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + if (error_code != 0) { + error_code = no_error; + PyErr_SetString(PyExc_RuntimeError, error_msg); + return nullptr; + } + Py_RETURN_NONE; } @@ -282,6 +315,12 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) { 
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + if (error_code != 0) { + error_code = no_error; + PyErr_SetString(PyExc_RuntimeError, error_msg); + return nullptr; + } + Py_RETURN_NONE; } diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md index aae04157b10de..a30e1fdf3ac77 100644 --- a/csrc/quantization/cutlass_w8a8/Epilogues.md +++ b/csrc/quantization/cutlass_w8a8/Epilogues.md @@ -1,17 +1,19 @@ # CUTLASS Epilogues ## Introduction -This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. Currently, we only support symmetric quantization for weights, and symmetric and asymmetric quantization for activations. Both can be quantized per-tensor or per-channel (weights) / per-token (activations). There are 4 epilogues: -1. ScaledEpilogue: symmetric quantization for activations, no bias. -1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. -1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. -1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. + +1. `ScaledEpilogue`: symmetric quantization for activations, no bias. +1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias. +1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias. +1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias. We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. Instead, if no bias is passed, the epilogue will use 0 as the bias. @@ -26,12 +28,15 @@ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following ```math A = s_a (\widehat A - J_a z_a) ``` + ```math B = s_b \widehat B ``` + ```math D = A B + C ``` + ```math D = s_a s_b \widehat D + C ``` @@ -48,9 +53,11 @@ Expanding further, we can calculate $` \widehat D `$ as follows: ```math A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B ``` + ```math A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) ``` + ```math \widehat D = \widehat A \widehat B - z_a J_a \widehat B ``` @@ -61,16 +68,19 @@ Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of ## Epilogues -### ScaledEpilogue +### `ScaledEpilogue` + This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B ``` + ```math D = s_a s_b \widehat D ``` + ```math D = s_a s_b \widehat A \widehat B ``` @@ -79,44 +89,51 @@ Epilogue parameters: - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). -### ScaledEpilogueBias +### `ScaledEpilogueBias` + This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B ``` + ```math D = s_a s_b \widehat D + C ``` + ```math D = s_a s_b \widehat A \widehat B + C ``` - Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). 
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). - `bias` is the bias, is always per-channel (row-vector). -### ScaledEpilogueAzp +### `ScaledEpilogueAzp` + This epilogue computes the asymmetric per-tensor quantization for activations with bias. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B - z_a J_a \widehat B ``` + ```math D = s_a s_b \widehat D + C ``` + ```math D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C ``` -Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. +Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. That is precomputed and stored in `azp_with_adj` as a row-vector. Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - Generally this will be per-tensor as the zero-points are per-tensor. - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). @@ -125,13 +142,15 @@ Epilogue parameters: To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. -### ScaledEpilogueAzpPerToken +### `ScaledEpilogueAzpPerToken` + This epilogue computes the asymmetric per-token quantization for activations with bias. The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - Generally this will be per-token as the zero-points are per-token. - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). @@ -142,6 +161,7 @@ Epilogue parameters: To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): -``` + +```math out = scale_a * scale_b * (Dq - azp_adj * azp) + bias ``` diff --git a/csrc/quantization/machete/Readme.md b/csrc/quantization/machete/Readme.md index 9ddf8da993b0e..6ffb2416b73b2 100644 --- a/csrc/quantization/machete/Readme.md +++ b/csrc/quantization/machete/Readme.md @@ -6,25 +6,25 @@ Machete is a spiritual successor to the Marlin kernel but optimized for Hopper a Machete effectively performs -``` +```python scale_type = w_s.dtype compute_type = a.dtype out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a ``` -Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and +Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and `w_z` is the quantization zeropoints. -> **_NOTE:_** `w_z` is added after the scales so we can +> **_NOTE:_** `w_z` is added after the scales so we can use FMA operations, but this means they must have the scales pre-applied if the -supplied zeropoints assume that they will be subtracted before the scales are +supplied zeropoints assume that they will be subtracted before the scales are applied. ## API The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. 
The flow looks something like: -``` +```python from vllm import _custom_ops as ops ... @@ -40,6 +40,6 @@ output = ops.machete_gemm( ## Code Generation -Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`. +Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`. -New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate. \ No newline at end of file +New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate. diff --git a/docs/seed_parameter_behavior.md b/docs/seed_parameter_behavior.md new file mode 100644 index 0000000000000..ff17525cf8e2f --- /dev/null +++ b/docs/seed_parameter_behavior.md @@ -0,0 +1,51 @@ +# Seed Parameter Behavior in vLLM + +## Overview + +The `seed` parameter in vLLM is used to control the random states for various random number generators. This parameter can affect the behavior of random operations in user code, especially when working with models in vLLM. + +## Default Behavior + +By default, the `seed` parameter is set to `None`. 
When the `seed` parameter is `None`, the global random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that the random operations will behave as expected, without any fixed random states. + +## Specifying a Seed + +If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly. This can be useful for reproducibility, as it ensures that the random operations produce the same results across multiple runs. + +## Example Usage + +### Without Specifying a Seed + +```python +import random +from vllm import LLM + +# Initialize a vLLM model without specifying a seed +model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct") + +# Try generating random numbers +print(random.randint(0, 100)) # Outputs different numbers across runs +``` + +### Specifying a Seed + +```python +import random +from vllm import LLM + +# Initialize a vLLM model with a specific seed +model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", seed=42) + +# Try generating random numbers +print(random.randint(0, 100)) # Outputs the same number across runs +``` + +## Important Notes + +- If the `seed` parameter is not specified, the behavior of global random states remains unaffected. +- If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set to that value. +- This behavior can be useful for reproducibility but may lead to non-intuitive behavior if the user is not explicitly aware of it. + +## Conclusion + +Understanding the behavior of the `seed` parameter in vLLM is crucial for ensuring the expected behavior of random operations in your code. By default, the `seed` parameter is set to `None`, which means that the global random states are not affected. However, specifying a seed value can help achieve reproducibility in your experiments. diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md index b0018ebccf5ba..ee5db70c7d5c8 100644 --- a/docs/source/features/compatibility_matrix.md +++ b/docs/source/features/compatibility_matrix.md @@ -297,7 +297,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ? - * [✗](gh-issue:7968>) + * [✗](gh-issue:7968) * ? * ✅ * diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index 027ddb6d5eda3..85a9e03739863 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -1,6 +1,6 @@ # Tool Calling -vLLM currently supports named function calling, as well as the `auto` and `none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but on the roadmap. +vLLM currently supports named function calling, as well as the `auto` and `none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but [on the roadmap](gh-issue:13002). ## Quickstart diff --git a/docs/source/getting_started/installation/gpu/rocm.inc.md b/docs/source/getting_started/installation/gpu/rocm.inc.md index c8fd11415cfda..7004313c90f32 100644 --- a/docs/source/getting_started/installation/gpu/rocm.inc.md +++ b/docs/source/getting_started/installation/gpu/rocm.inc.md @@ -1,6 +1,6 @@ # Installation -vLLM supports AMD GPUs with ROCm 6.2. +vLLM supports AMD GPUs with ROCm 6.3. 
:::{attention} There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. @@ -9,7 +9,7 @@ There are no pre-built wheels for this device, so you must either use the pre-bu ## Requirements - GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -- ROCm 6.2 +- ROCm 6.3 ## Set up using Python @@ -24,9 +24,15 @@ Currently, there are no pre-built ROCm wheels. - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html) - [PyTorch](https://pytorch.org/) - For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`. + For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. - Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/) + Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example: + + ```console + # Install PyTorch + $ pip uninstall torch -y + $ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.3 + ``` 1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton) @@ -37,7 +43,7 @@ Currently, there are no pre-built ROCm wheels. pip uninstall -y triton git clone https://github.com/OpenAI/triton.git cd triton - git checkout e192dba + git checkout e5be006 cd python pip3 install . cd ../.. @@ -49,15 +55,15 @@ Currently, there are no pre-built ROCm wheels. 2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile) - Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support) + Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support) Alternatively, wheels intended for vLLM use can be accessed under the releases. - For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. + For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. ```console git clone https://github.com/ROCm/flash-attention.git cd flash-attention - git checkout 3cea2fb + git checkout b7d29fb git submodule update --init GPU_ARCHS="gfx90a" python3 setup.py install cd .. @@ -67,20 +73,16 @@ Currently, there are no pre-built ROCm wheels. You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) ::: -3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps: +3. Build vLLM. 
For example, vLLM on ROCM 6.3 can be built with the following steps: ```bash $ pip install --upgrade pip - # Install PyTorch - $ pip uninstall torch -y - $ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2 - # Build & install AMD SMI $ pip install /opt/rocm/share/amd_smi # Install dependencies - $ pip install --upgrade numba scipy huggingface-hub[cli] + $ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm $ pip install "numpy<2" $ pip install -r requirements-rocm.txt @@ -91,12 +93,11 @@ Currently, there are no pre-built ROCm wheels. This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation. - :::{tip} - - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. - - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. - - The ROCm version of PyTorch, ideally, should match the ROCm driver version. + - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. + - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. + - The ROCm version of PyTorch, ideally, should match the ROCm driver version. ::: :::{tip} @@ -104,7 +105,7 @@ Currently, there are no pre-built ROCm wheels. For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization). ::: -## Set up using Docker +## Set up using Docker (Recommended) ### Pre-built images @@ -120,7 +121,12 @@ for instructions on how to use this prebuilt docker image. Building the Docker image from source is the recommended way to use vLLM with ROCm. -First, build a docker image from and launch a docker container from the image. +#### (Optional) Build an image with ROCm software stack + +Build a docker image from which setup ROCm software stack needed by the vLLM. +**This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** +If you choose to build this rocm_base image yourself, the steps are as follows. + It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```console @@ -131,7 +137,26 @@ It is important that the user kicks off the docker build using buildkit. Either } ``` - uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches. +To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: + +```console +DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm_base -t rocm/vllm-dev:base . 
+``` + +#### Build an image with vLLM + +First, build a docker image from and launch a docker container from the image. +It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: + +```console +{ + "features": { + "buildkit": true + } +} +``` + + uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2, in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: - `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using @@ -141,13 +166,13 @@ It provides flexibility to customize the build of docker image using the followi Their values can be passed in when running `docker build` with `--build-arg` options. -To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default: +To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: ```console DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm . ``` -To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: +To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: ```console DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm . diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 32f3e9deff671..55b3f52356cd0 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -429,7 +429,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `TeleChat2ForCausalLM` * TeleChat2 - * `TeleAI/TeleChat2-3B`, `TeleAI/TeleChat2-7B`, `TeleAI/TeleChat2-35B`, etc. + * `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. * ✅︎ * ✅︎ - * `XverseForCausalLM` @@ -719,7 +719,7 @@ See [this page](#generative-models) for more information on how to use generativ * `THUDM/glm-4v-9b` etc. * ✅︎ * ✅︎ - * + * ✅︎ - * `H2OVLChatModel` * H2OVL * T + IE+ @@ -856,7 +856,7 @@ See [this page](#generative-models) for more information on how to use generativ - * `UltravoxModel` * Ultravox * T + AE+ - * `fixie-ai/ultravox-v0_3` + * `fixie-ai/ultravox-v0_5-llama-3_2-1b` * ✅︎ * ✅︎ * ✅︎ diff --git a/docs/source/serving/engine_args.md b/docs/source/serving/engine_args.md index 827c25b50522f..f4587b94edeaf 100644 --- a/docs/source/serving/engine_args.md +++ b/docs/source/serving/engine_args.md @@ -4,7 +4,7 @@ Below, you can find an explanation of every engine argument for vLLM: - + ```{eval-rst} .. argparse:: :module: vllm.engine.arg_utils @@ -17,7 +17,7 @@ Below, you can find an explanation of every engine argument for vLLM: Below are the additional arguments related to the asynchronous engine: - + ```{eval-rst} .. 
argparse:: :module: vllm.engine.arg_utils diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index 217b531e83788..ade59e3773839 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -359,12 +359,12 @@ export VLLM_VIDEO_FETCH_TIMEOUT= ### Audio Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in). -Here is a simple example using Ultravox-v0.3. +Here is a simple example using Ultravox-v0.5-1B. First, launch the OpenAI-compatible server: ```bash -vllm serve fixie-ai/ultravox-v0_3 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b ``` Then, you can use the OpenAI client as follows: diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 707ca9f878961..3e3034a02f0f1 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -24,9 +24,9 @@ # Unless specified, these settings have been tested to work on a single L4. -# Ultravox 0.3 +# Ultravox 0.5-1B def run_ultravox(question: str, audio_count: int): - model_name = "fixie-ai/ultravox-v0_3" + model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py new file mode 100644 index 0000000000000..2e41cabaccafc --- /dev/null +++ b/examples/offline_inference/disaggregated_prefill.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of disaggregated prefilling +We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), +and then transfer the KV cache between them. +""" +import os +import time +from multiprocessing import Event, Process + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def run_prefill(prefill_done): + # We use GPU 0 for prefill node. + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + # The prefill node receives two requests, while the decode node receives + # three requests. So the decode node will only receive the KV Cache for + # requests 1 and 3. The decode node will use the KV Cache of requests 1 + # and 3 and do prefilling on request 2. + prompts = [ + "Hello, my name is", + # "Hi, your name is", + # The decode node will actually "prefill" this request. + "Tell me a very long story", + ] + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + # Using PyNcclConnector to transmit KV caches between vLLM instances. + # This instance is the prefill node (kv_producer, rank 0). + # The number of parallel instances for KV cache transfer is set to 2, + # as required for PyNcclConnector. + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' + ) + + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. You may need to adjust the value to fit your GPU. + llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8) + + llm.generate(prompts, sampling_params) + print("Prefill node is finished.") + prefill_done.set() + + # To keep the prefill node running in case the decode node is not done; + # otherwise, the script might exit prematurely, causing incomplete decoding. 
+ try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("Script stopped by user.") + + +def run_decode(prefill_done): + # We use GPU 1 for decode node. + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + prompts = [ + "Hello, my name is", + "Hi, your name is", + "Tell me a very long story", + ] + sampling_params = SamplingParams(temperature=0, top_p=0.95) + + # Using PyNcclConnector to transmit KV caches between vLLM instances. + # This instance is the decode node (kv_consumer, rank 1). + # The number of parallel instances for KV cache transfer is set to 2, + # as required for PyNcclConnector. + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' + ) + + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. You may need to adjust the value to fit your GPU. + llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8) + + # Wait for the producer to start the pipe + print("Waiting for prefill node to finish...") + prefill_done.wait() + + # At this point when the prefill_done is set, the kv-cache should have been + # transferred to this decode node, so we can start decoding. + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + prefill_done = Event() + prefill_process = Process(target=run_prefill, args=(prefill_done, )) + decode_process = Process(target=run_decode, args=(prefill_done, )) + + # Start prefill node + prefill_process.start() + + # Start decode node + decode_process.start() + + # Terminate the prefill node when decode is finished + decode_process.join() + prefill_process.terminate() diff --git a/examples/offline_inference/openai/openai_batch.md b/examples/offline_inference/openai/openai_batch.md index 953e6ef130f18..d271573aa96fc 100644 --- a/examples/offline_inference/openai/openai_batch.md +++ b/examples/offline_inference/openai/openai_batch.md @@ -5,50 +5,49 @@ This is a guide to performing batch inference using the OpenAI batch file format ``` ## File Format - + The OpenAI batch file format consists of a series of json objects on new lines. - + [See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai/openai_example_batch.jsonl) - + Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. - + ```{note} We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` endpoints (completions coming soon). ``` - + ## Pre-requisites * The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`. - Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens) - Install the token on your machine (Run `huggingface-cli login`). - Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions. - - + ## Example 1: Running with a local file ### Step 1: Create your batch file To follow along with this example, you can download the example batch, or create your own batch file in your working directory. 
-``` +```console wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this -``` +```console $ cat offline_inference/openai/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` ### Step 2: Run the batch - + The batch running tool is designed to be used from the command line. You can run the batch with the following command, which will write its results to a file called `results.jsonl` -``` +```console python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` @@ -56,7 +55,7 @@ python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai/openai_e You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl` -``` +```console $ cat results.jsonl {"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null} {"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null} @@ -68,7 +67,7 @@ The batch runner supports remote input and output urls that are accessible via h For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl`, you can run -``` +```console python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` @@ -80,7 +79,7 @@ To integrate with cloud blob storage, we recommend using presigned urls. ### Additional prerequisites -* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html). +* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html). * The `awscli` package (Run `pip install awscli`) to configure your credentials and interactively use s3. 
- [Configure your credentials](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html). * The `boto3` python package (Run `pip install boto3`) to generate presigned urls. @@ -89,13 +88,13 @@ To integrate with cloud blob storage, we recommend using presigned urls. To follow along with this example, you can download the example batch, or create your own batch file in your working directory. -``` +```console wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this -``` +```console $ cat offline_inference/openai/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} @@ -103,7 +102,7 @@ $ cat offline_inference/openai/openai_example_batch.jsonl Now upload your batch file to your S3 bucket. -``` +```console aws s3 cp offline_inference/openai/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl ``` @@ -111,9 +110,9 @@ aws s3 cp offline_inference/openai/openai_example_batch.jsonl s3://MY_BUCKET/MY_ Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names. -(The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py) +(The script is adapted from ) -``` +```python import boto3 from botocore.exceptions import ClientError @@ -149,7 +148,7 @@ print(f"{output_url=}") This script should output -``` +```text input_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' ``` @@ -158,7 +157,7 @@ output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AW You can now run the batch runner, using the urls generated in the previous section. -``` +```console python -m vllm.entrypoints.openai.run_batch \ -i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ -o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ @@ -169,7 +168,7 @@ python -m vllm.entrypoints.openai.run_batch \ Your results are now on S3. You can view them in your terminal by running -``` +```console aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - ``` @@ -180,10 +179,10 @@ aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - * Ensure you are using `vllm >= 0.5.5`. ### Step 1: Create your batch file - + Add embedding requests to your batch file. 
The following is an example: - -``` + +```text {"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}} {"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}} ``` @@ -198,7 +197,7 @@ You can run the batch using the same command as in earlier examples. You can check your results by running `cat results.jsonl` -``` +```console $ cat results.jsonl {"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null} ... @@ -211,10 +210,10 @@ $ cat results.jsonl * Ensure you are using `vllm >= 0.7.0`. ### Step 1: Create your batch file - + Add score requests to your batch file. The following is an example: - -``` + +```text {"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} ``` @@ -229,7 +228,7 @@ You can run the batch using the same command as in earlier examples. You can check your results by running `cat results.jsonl` -``` +```console $ cat results.jsonl {"id":"vllm-f87c5c4539184f618e555744a2965987","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-806ab64512e44071b37d3f7ccd291413","body":{"id":"score-4ee45236897b4d29907d49b01298cdb1","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.0010900497436523438},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null} {"id":"vllm-41990c51a26d4fac8419077f12871099","custom_id":"request-2","response":{"status_code":200,"request_id":"vllm-batch-73ce66379026482699f81974e14e1e99","body":{"id":"score-13f2ffe6ba40460fbf9f7f00ad667d75","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.001094818115234375},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null} diff --git a/examples/offline_inference/profiling_tpu/README.md b/examples/offline_inference/profiling_tpu/README.md index 08efa63dc1021..6595efec43779 100644 --- a/examples/offline_inference/profiling_tpu/README.md +++ b/examples/offline_inference/profiling_tpu/README.md @@ -29,7 +29,6 @@ python3 profiling.py \ --profile-result-dir profiles ``` - ### Generate Decode Trace This example runs Llama 3.1 70B with a batch of 32 requests where each has 1 input token and 128 output tokens. 
This is set up in attempt to profile just the 32 decodes running in parallel by having an extremely small prefill of 1 token and setting `VLLM_TPU_PROFILE_DELAY_MS=1000` to skip the first second of inference (hopefully prefill). @@ -51,17 +50,18 @@ python3 profiling.py \ --max-model-len 2048 --tensor-parallel-size 8 ``` - ## Visualizing the profiles Once you have collected your profiles with this script, you can visualize them using [TensorBoard](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). Here are most likely the dependencies you need to install: + ```bash pip install tensorflow-cpu tensorboard-plugin-profile etils importlib_resources ``` Then you just need to point TensorBoard to the directory where you saved the profiles and visit `http://localhost:6006/` in your browser: + ```bash tensorboard --logdir profiles/ --port 6006 -``` \ No newline at end of file +``` diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/rlhf_colocate.py similarity index 56% rename from examples/offline_inference/ray_placement.py rename to examples/offline_inference/rlhf_colocate.py index cd801a3c0c858..b921bc71feb99 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 """ -a simple demonstration to show how to control -the placement of the vLLM workers with Ray. -The key is to set VLLM_RAY_PER_WORKER_GPUS and -VLLM_RAY_BUNDLE_INDICES properly. +a simple demonstration to show how to co-locate +vLLM worker with training actors on the same GPUs, +for RLHF-like applications. +The key points: +- Control the placement of the vLLM workers with Ray, by setting + VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly. +- Use cuda-ipc to pass tensors, since NCCL does not work when we have + multiple processes on the same GPU. """ import os import ray +import torch from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -19,7 +24,33 @@ class MyWorker(Worker): def report_device_id(self) -> str: from vllm.platforms import current_platform - return current_platform.get_device_uuid(self.device.index) + self.device_uuid = current_platform.get_device_uuid(self.device.index) + return self.device_uuid + + def update_weights_from_ipc_handles(self, ipc_handles): + handles = ipc_handles[self.device_uuid] + device_id = self.device.index + weights = [] + for name, handle in handles.items(): + func, args = handle + list_args = list(args) + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + tensor = func(*list_args) + weights.append((name, tensor)) + self.model_runner.model.load_weights(weights=weights) + torch.cuda.synchronize() + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. 
+ """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated class MyLLM(LLM): @@ -40,12 +71,32 @@ def __init__(self, *args, bundle_indices: list, **kwargs): class RayTrainingActor: - def report_device_id(self) -> str: + def __init__(self): + # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs + from transformers import AutoModelForCausalLM + self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + self.model.to("cuda:0") + for name, p in self.model.named_parameters(): + p.data.zero_() + torch.cuda.synchronize() # the argument for get_device_uuid is the index # of the GPU in the visible devices. - # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs from vllm.platforms import current_platform - return current_platform.get_device_uuid(0) + self.device_uuid = current_platform.get_device_uuid(0) + + def report_device_id(self) -> str: + return self.device_uuid + + def get_weight_ipc_handles(self): + from torch.multiprocessing.reductions import reduce_tensor + data = {} + for name, p in self.model.named_parameters(): + # the training actor might only have a subset of the weights + # and need to all-gather the weights from all the actors. + # for demonstration, here we assume all training actors have + # the full weights. + data[name] = reduce_tensor(p.detach()) + return {self.device_uuid: data} # ray manages 4 GPUs @@ -78,6 +129,8 @@ def report_device_id(self) -> str: ), )(RayTrainingActor).remote() training_actors.append(training_actor) + +for bundle_index, training_actor in enumerate(training_actors): device_id = ray.get(training_actor.report_device_id.remote()) print(f"training actor {bundle_index} is on {device_id}") training_actor_device_ids.append(device_id) @@ -119,3 +172,18 @@ def report_device_id(self) -> str: # the last two training actors should be # on the same GPUs as the second inference engine assert training_actor_device_ids[2:] == inference_engine_device_ids[1] + +print("gather all the IPC handles from the training actors") +ipc_handles = {} +for actor in training_actors: + ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) + +print("update the weights of the inference engines") +for llm in inference_engines: + ray.get( + llm.collective_rpc.remote("update_weights_from_ipc_handles", + args=(ipc_handles, ))) +print("check if the weights are updated") +for llm in inference_engines: + assert ray.get( + llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 436c36570599a..9a4183106cff9 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -106,7 +106,9 @@ def run_glm4v(question: str, modality: str): trust_remote_code=True, enforce_eager=True, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - prompt = question + prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ + {question}<|assistant|>" + stop_token_ids = [151329, 151336, 151338] return llm, prompt, stop_token_ids diff --git a/examples/online_serving/chart-helm/README.md b/examples/online_serving/chart-helm/README.md index 6aa126d4fd22c..bfe81121d1fd4 100644 --- a/examples/online_serving/chart-helm/README.md +++ b/examples/online_serving/chart-helm/README.md @@ -18,4 +18,4 @@ This directory contains a Helm chart for deploying the vllm 
application. The cha - templates/poddisruptionbudget.yaml: Template for Pod Disruption Budget. - templates/pvc.yaml: Template for Persistent Volume Claims. - templates/secrets.yaml: Template for Kubernetes Secrets. -- templates/service.yaml: Template for creating Services. \ No newline at end of file +- templates/service.yaml: Template for creating Services. diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index d5f798a8dae62..ecfcf05a90d16 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -12,7 +12,7 @@ --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 (audio inference with Ultravox) -vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 """ import base64 diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py index a88c8adb55c28..b5dbed1205d35 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning.py @@ -36,8 +36,8 @@ reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content -print("reasoning_content:", reasoning_content) -print("content:", content) +print("reasoning_content for Round 1:", reasoning_content) +print("content for Round 1:", content) # Round 2 messages.append({"role": "assistant", "content": content}) @@ -50,5 +50,5 @@ reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content -print("reasoning_content:", reasoning_content) -print("content:", content) +print("reasoning_content for Round 2:", reasoning_content) +print("content for Round 2:", content) diff --git a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py index f49d7a228191c..e410620378a52 100644 --- a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py @@ -44,7 +44,7 @@ def vlm2vec(): def dse_qwen2_vl(inp: dict): # Embedding an Image - if inp["dtype"] == "image": + if inp["type"] == "image": messages = [{ "role": "user", @@ -113,10 +113,10 @@ def dse_qwen2_vl(inp: dict): vlm2vec() elif args.model == "dse_qwen2_vl": dse_qwen2_vl({ - "dtye": "image", + "type": "image", "image_url": image_url, }) dse_qwen2_vl({ - "dtype": "text", + "type": "text", "content": "What is the weather like today?", }) diff --git a/examples/online_serving/opentelemetry/Otel.md b/examples/online_serving/opentelemetry/Otel.md index 96d1f96bfa144..af00340079745 100644 --- a/examples/online_serving/opentelemetry/Otel.md +++ b/examples/online_serving/opentelemetry/Otel.md @@ -1,7 +1,8 @@ # Setup OpenTelemetry POC 1. Install OpenTelemetry packages: - ``` + + ```console pip install \ 'opentelemetry-sdk>=1.26.0,<1.27.0' \ 'opentelemetry-api>=1.26.0,<1.27.0' \ @@ -10,7 +11,8 @@ ``` 1. Start Jaeger in a docker container: - ``` + + ```console # From: https://www.jaegertracing.io/docs/1.57/getting-started/ docker run --rm --name jaeger \ -e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \ @@ -28,19 +30,23 @@ ``` 1. 
In a new shell, export Jaeger IP: - ``` + + ```console export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 ``` + Then set vLLM's service name for OpenTelemetry, enable insecure connections to Jaeger and run vLLM: - ``` + + ```console export OTEL_SERVICE_NAME="vllm-server" export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" ``` 1. In a new shell, send requests with trace context from a dummy client - ``` + + ```console export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true @@ -48,7 +54,7 @@ python dummy_client.py ``` -1. Open Jaeger webui: http://localhost:16686/ +1. Open Jaeger webui: In the search pane, select `vllm-server` service and hit `Find Traces`. You should get a list of traces, one for each request. ![Traces](https://i.imgur.com/GYHhFjo.png) @@ -57,26 +63,32 @@ ![Spans details](https://i.imgur.com/OPf6CBL.png) ## Exporter Protocol + OpenTelemetry supports either `grpc` or `http/protobuf` as the transport protocol for trace data in the exporter. By default, `grpc` is used. To set `http/protobuf` as the protocol, configure the `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` environment variable as follows: -``` + +```console export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" ``` ## Instrumentation of FastAPI + OpenTelemetry allows automatic instrumentation of FastAPI. + 1. Install the instrumentation library - ``` + + ```console pip install opentelemetry-instrumentation-fastapi ``` 1. Run vLLM with `opentelemetry-instrument` - ``` + + ```console opentelemetry-instrument vllm serve facebook/opt-125m ``` 1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI. -![FastAPI Spans](https://i.imgur.com/hywvoOJ.png) \ No newline at end of file +![FastAPI Spans](https://i.imgur.com/hywvoOJ.png) diff --git a/examples/online_serving/prometheus_grafana/README.md b/examples/online_serving/prometheus_grafana/README.md index 4a85f953b0b4c..6df9594516664 100644 --- a/examples/online_serving/prometheus_grafana/README.md +++ b/examples/online_serving/prometheus_grafana/README.md @@ -1,14 +1,16 @@ -# Prometheus and Grafana +# Prometheus and Grafana -This is a simple example that shows you how to connect vLLM metric logging to the Prometheus/Grafana stack. For this example, we launch Prometheus and Grafana via Docker. You can checkout other methods through [Prometheus](https://prometheus.io/) and [Grafana](https://grafana.com/) websites. +This is a simple example that shows you how to connect vLLM metric logging to the Prometheus/Grafana stack. For this example, we launch Prometheus and Grafana via Docker. You can checkout other methods through [Prometheus](https://prometheus.io/) and [Grafana](https://grafana.com/) websites. + +Install: -Install: - [`docker`](https://docs.docker.com/engine/install/) - [`docker compose`](https://docs.docker.com/compose/install/linux/#install-using-the-repository) ## Launch Prometheus metric logging is enabled by default in the OpenAI-compatible server. 
Launch via the entrypoint: + ```bash vllm serve mistralai/Mistral-7B-v0.1 \ --max-model-len 2048 \ @@ -16,11 +18,13 @@ vllm serve mistralai/Mistral-7B-v0.1 \ ``` Launch Prometheus and Grafana servers with `docker compose`: + ```bash docker compose up ``` Submit some sample requests to the server: + ```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json @@ -41,13 +45,13 @@ Navigate to [`http://localhost:3000`](http://localhost:3000). Log in with the de ### Add Prometheus Data Source -Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus. +Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus. On Prometheus configuration page, we need to add the `Prometheus Server URL` in `Connection`. For this setup, Grafana and Prometheus are running in separate containers, but Docker creates DNS name for each containers. You can just use `http://prometheus:9090`. Click `Save & Test`. You should get a green check saying "Successfully queried the Prometheus API.". -### Import Dashboard +### Import Dashboard Navigate to [`http://localhost:3000/dashboard/import`](http://localhost:3000/dashboard/import), upload `grafana.json`, and select the `prometheus` datasource. You should see a screen that looks like the following: diff --git a/examples/other/logging_configuration.md b/examples/other/logging_configuration.md index 9ac8b13cd5eaf..acd9c1f2bc0a5 100644 --- a/examples/other/logging_configuration.md +++ b/examples/other/logging_configuration.md @@ -15,7 +15,6 @@ more-complex-and-more-flexible. - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` and set `VLLM_LOGGING_CONFIG_PATH=` - ## Logging Configuration Environment Variables ### `VLLM_CONFIGURE_LOGGING` @@ -45,7 +44,6 @@ schema](https://docs.python.org/3/library/logging.config.html#dictionary-schema- If `VLLM_LOGGING_CONFIG_PATH` is specified, but `VLLM_CONFIGURE_LOGGING` is disabled, an error will occur while starting vLLM. 
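As an editorial aside (not part of the original guide): before pointing `VLLM_LOGGING_CONFIG_PATH` at a custom file, you can sanity-check that the file is valid JSON and follows the logging configuration dictionary schema using only the Python standard library. The path below is the same placeholder used in the examples that follow.

```python
# Minimal sketch (assumption: "/path/to/logging_config.json" is a placeholder
# for your own config file; this helper script is not part of vLLM).
import json
import logging.config

with open("/path/to/logging_config.json") as f:
    # json.load raises json.JSONDecodeError if the file is not valid JSON.
    config = json.load(f)

# dictConfig raises ValueError if the dict does not follow the
# logging configuration dictionary schema.
logging.config.dictConfig(config)
print("logging configuration looks valid")
```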
- ## Examples ### Example 1: Customize vLLM root logger @@ -98,7 +96,6 @@ VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` - ### Example 2: Silence a particular vLLM logger To silence a particular vLLM logger, it is necessary to provide custom logging @@ -153,7 +150,6 @@ VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` - ### Example 3: Disable vLLM default logging configuration To disable vLLM's default logging configuration and silence all vLLM loggers, @@ -166,7 +162,6 @@ VLLM_CONFIGURE_LOGGING=0 \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` - ## Additional resources - [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 78fa360f2dc96..0e7217fb3769e 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -3,7 +3,6 @@ # Dependencies for NVIDIA GPUs ray[default] >= 2.9 -nvidia-ml-py >= 12.560.30 # for pynvml package torch == 2.5.1 torchaudio==2.5.1 # These must be updated alongside torch diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 5e08d101fcd61..09820c73e4e00 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -2,6 +2,5 @@ -r requirements-common.txt # Dependencies for Neuron devices -transformers-neuronx >= 0.13.0 torch-neuronx >= 2.5.0 neuronx-cc diff --git a/setup.py b/setup.py index a4043c43a7d5b..27e5aab760f9a 100755 --- a/setup.py +++ b/setup.py @@ -47,6 +47,12 @@ def load_module_from_path(module_name, path): "Building on %s, " "so vLLM may not be able to run correctly", sys.platform) VLLM_TARGET_DEVICE = "empty" +elif (sys.platform.startswith("linux") and torch.version.cuda is None + and os.getenv("VLLM_TARGET_DEVICE") is None + and torch.version.hip is None): + # if cuda or hip is not available and VLLM_TARGET_DEVICE is not set, + # fallback to cpu + VLLM_TARGET_DEVICE = "cpu" MAIN_CUDA_VERSION = "12.1" @@ -369,12 +375,7 @@ def _is_hip() -> bool: def _is_neuron() -> bool: - torch_neuronx_installed = True - try: - subprocess.run(["neuron-ls"], capture_output=True, check=True) - except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): - torch_neuronx_installed = False - return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron" + return VLLM_TARGET_DEVICE == "neuron" def _is_tpu() -> bool: @@ -482,7 +483,6 @@ def get_vllm_version() -> str: version = get_version( write_to="vllm/_version.py", # TODO: move this to pyproject.toml ) - sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): @@ -520,7 +520,8 @@ def get_vllm_version() -> str: elif _is_tpu(): version += f"{sep}tpu" elif _is_cpu(): - version += f"{sep}cpu" + if envs.VLLM_TARGET_DEVICE == "cpu": + version += f"{sep}cpu" elif _is_xpu(): version += f"{sep}xpu" else: diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index da9239b094076..3ac948799d77c 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest import torch from vllm import LLM, SamplingParams @@ -9,6 +10,32 @@ from ..utils import fork_new_process_for_each_test +@fork_new_process_for_each_test +def test_python_error(): + """ + Test if Python error occurs when there's low-level + error happening from the C++ side. 
+ """ + allocator = CuMemAllocator.get_instance() + total_bytes = torch.cuda.mem_get_info()[1] + alloc_bytes = int(total_bytes * 0.7) + tensors = [] + with allocator.use_memory_pool(): + # allocate 70% of the total memory + x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + tensors.append(x) + # release the memory + allocator.sleep() + + # allocate more memory than the total memory + y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + tensors.append(y) + with pytest.raises(RuntimeError): + # when the allocator is woken up, it should raise an error + # because we don't have enough memory + allocator.wake_up() + + @fork_new_process_for_each_test def test_basic_cumem(): # some tensors from default memory pool @@ -88,10 +115,16 @@ def model(x): @fork_new_process_for_each_test -def test_end_to_end(): +@pytest.mark.parametrize( + "model", + [ + "meta-llama/Llama-3.2-1B", # sleep mode with safetensors + "facebook/opt-125m" # sleep mode with pytorch checkpoint + ]) +def test_end_to_end(model): free, total = torch.cuda.mem_get_info() used_bytes_baseline = total - free # in case other process is running - llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True) + llm = LLM(model, enable_sleep_mode=True) prompt = "How are you?" sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 5b6741d74efc0..5d7cb9e408909 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -215,7 +215,7 @@ def iter_params(self, model_name: str): "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), - "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True), + "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 # [Encoder-decoder] # TODO: Implement PP # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), @@ -234,7 +234,7 @@ def iter_params(self, model_name: str): # [MULTIMODAL GENERATION] "OpenGVLab/InternVL2-1B", "microsoft/Phi-3-vision-128k-instruct", - "fixie-ai/ultravox-v0_3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", # [LANGUAGE GENERATION - HYBRID ARCH] "ai21labs/Jamba-tiny-dev", ] diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_executor.py similarity index 79% rename from tests/engine/test_custom_executor.py rename to tests/engine/test_executor.py index 3e77faecbd3f5..84cc3ed63bb93 100644 --- a/tests/engine/test_custom_executor.py +++ b/tests/engine/test_executor.py @@ -55,6 +55,7 @@ def test_custom_executor(model, tmp_path): engine_args = EngineArgs( model=model, distributed_executor_backend=CustomUniExecutor, + enforce_eager=True, # reduce test time ) engine = LLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -75,7 +76,10 @@ def test_custom_executor_async(model, tmp_path): assert not os.path.exists(".marker") engine_args = AsyncEngineArgs( - model=model, distributed_executor_backend=CustomUniExecutorAsync) + model=model, + distributed_executor_backend=CustomUniExecutorAsync, + enforce_eager=True, # reduce test time + ) engine = AsyncLLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -89,3 +93,18 @@ async def t(): assert os.path.exists(".marker") finally: os.chdir(cwd) + + 
+@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_respect_ray(model): + # even for TP=1 and PP=1, + # if users specify ray, we should use ray. + # users might do this if they want to manage the + # resources using ray. + engine_args = EngineArgs( + model=model, + distributed_executor_backend="ray", + enforce_eager=True, # reduce test time + ) + engine = LLMEngine.from_engine_args(engine_args) + assert engine.model_executor.uses_ray diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index eac76f2ba0fa5..85156d6931c8c 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -4,6 +4,7 @@ import sys import pytest +import urllib3 from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory @@ -28,6 +29,15 @@ "tensor_parallel_size": 1, "tokenizer_mode": "mistral", }, + { + "model": "sentence-transformers/all-MiniLM-L12-v2", + "enforce_eager": True, + "gpu_memory_utilization": 0.20, + "max_model_len": 64, + "max_num_batched_tokens": 64, + "max_num_seqs": 64, + "tensor_parallel_size": 1, + }, ] @@ -47,6 +57,16 @@ def test_offline_mode(monkeypatch): # Set HF to offline mode and ensure we can still construct an LLM try: monkeypatch.setenv("HF_HUB_OFFLINE", "1") + monkeypatch.setenv("VLLM_NO_USAGE_STATS", "1") + + def disable_connect(*args, **kwargs): + raise RuntimeError("No http calls allowed") + + monkeypatch.setattr(urllib3.connection.HTTPConnection, "connect", + disable_connect) + monkeypatch.setattr(urllib3.connection.HTTPSConnection, "connect", + disable_connect) + # Need to re-import huggingface_hub and friends to setup offline mode _re_import_modules() # Cached model files should be used in offline mode @@ -56,6 +76,7 @@ def test_offline_mode(monkeypatch): # Reset the environment after the test # NB: Assuming tests are run in online mode monkeypatch.delenv("HF_HUB_OFFLINE") + monkeypatch.delenv("VLLM_NO_USAGE_STATS") _re_import_modules() pass diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py index f7b81be48bd11..fdadb2e21ff80 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py +++ b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py @@ -15,32 +15,62 @@ end_token = "" SIMPLE_REASONING = { - "output": "This is a reasoning sectionThis is the rest", + "output": "This is a reasoning sectionThis is the rest", "reasoning_content": "This is a reasoning section", "content": "This is the rest", } COMPLETE_REASONING = { - "output": "This is a reasoning section", + "output": "This is a reasoning section", "reasoning_content": "This is a reasoning section", "content": None, } NO_REASONING = { - "output": "This is a reasoning section", + "output": "This is content", "reasoning_content": None, - "content": "This is a reasoning section", + "content": "This is content", +} +NO_REASONING_STREAMING = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, } MULTIPLE_LINES = { - "output": "This\nThatThis is the rest\nThat", + "output": "This\nThatThis is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } SHORTEST_REASONING_NO_STREAMING = { - "output": "This is the rest", + "output": "This is the rest", "reasoning_content": "", "content": "This is the 
rest", } SHORTEST_REASONING = { - "output": "This is the rest", + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", +} +REASONING_WITH_THINK = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING_WITH_THINK = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, +} +MULTIPLE_LINES_WITH_THINK = { + "output": "This\nThatThis is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", +} +SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { + "output": "This is the rest", + "reasoning_content": "", + "content": "This is the rest", +} +SHORTEST_REASONING_WITH_THINK = { + "output": "This is the rest", "reasoning_content": None, "content": "This is the rest", } @@ -49,37 +79,37 @@ pytest.param( False, SIMPLE_REASONING, - id="simple_streaming", + id="simple_reasoning", ), pytest.param( True, SIMPLE_REASONING, - id="simple_streaming", + id="simple_reasoning_streaming", ), pytest.param( False, COMPLETE_REASONING, - id="complete_streaming", + id="complete_reasoning", ), pytest.param( True, COMPLETE_REASONING, - id="complete_streaming", + id="complete_reasoning_streaming", ), pytest.param( False, NO_REASONING, - id="no_streaming", + id="no_reasoning_token", ), pytest.param( True, - NO_REASONING, - id="no_streaming", + NO_REASONING_STREAMING, + id="no_reasoning_token_streaming", ), pytest.param( False, MULTIPLE_LINES, - id="multiple_lines_streaming", + id="multiple_lines", ), pytest.param( True, @@ -89,23 +119,65 @@ pytest.param( True, SHORTEST_REASONING, - id="shortest_streaming", + id="shortest", ), pytest.param( False, SHORTEST_REASONING_NO_STREAMING, id="shortest_streaming", ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING_WITH_THINK, + id="shortest_with_think", + ), + pytest.param( + True, + SHORTEST_REASONING_WITH_THINK, + id="shortest_with_think_streaming", + ), ] +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") +tokenizer.add_tokens([start_token, end_token]) + @pytest.mark.parametrize("streaming, param_dict", TEST_CASES) def test_reasoning( streaming: bool, param_dict: dict, ): - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - tokenizer.add_tokens([start_token, end_token]) output = tokenizer.tokenize(param_dict["output"]) # decode everything to tokens output_tokens: List[str] = [ diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 6e206dfd99b6a..fe7299a48e6f6 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -11,7 +11,7 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "fixie-ai/ultravox-v0_3" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" TEST_AUDIO_URLS = [ 
AudioAsset("winning_call").url, ] @@ -83,7 +83,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=201, total_tokens=211) message = choice.message message = chat_completion.choices[0].message @@ -140,7 +140,7 @@ async def test_single_chat_session_audio_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=201, total_tokens=211) message = choice.message message = chat_completion.choices[0].message @@ -196,7 +196,7 @@ async def test_single_chat_session_input_audio( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=201, total_tokens=211) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index de2333901cc91..8c1bb1a897e37 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -203,6 +203,8 @@ async def test_metrics_counts(server: RemoteOpenAIServer, "vllm:num_requests_running", "vllm:num_requests_waiting", "vllm:gpu_cache_usage_perc", + "vllm:gpu_prefix_cache_queries", + "vllm:gpu_prefix_cache_hits", "vllm:prompt_tokens_total", "vllm:generation_tokens_total", "vllm:request_success_total", diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 029c9b038b047..c954fca696ffa 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -92,7 +92,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=775, total_tokens=785) + completion_tokens=10, prompt_tokens=774, total_tokens=784) message = choice.message message = chat_completion.choices[0].message @@ -185,7 +185,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=775, total_tokens=785) + completion_tokens=10, prompt_tokens=774, total_tokens=784) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index f2ff4a0b07a5f..cee5274561f47 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -93,5 +93,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, assert len(embeddings.data) == 1 assert len(embeddings.data[0].embedding) == 3072 assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens == 764 - assert embeddings.usage.total_tokens == 764 + assert embeddings.usage.prompt_tokens == 763 + assert embeddings.usage.total_tokens == 763 diff --git 
a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 5c469007af23e..c52fa905c80b3 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -21,7 +21,7 @@ EXAMPLES_DIR = VLLM_PATH / "examples" PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" -ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3" +ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" diff --git a/tests/kernels/test_mamba_mixer2.py b/tests/kernels/test_mamba_mixer2.py new file mode 100644 index 0000000000000..8c441fcbe61e2 --- /dev/null +++ b/tests/kernels/test_mamba_mixer2.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Tuple + +import pytest +import torch + +from tests.utils import multi_gpu_test +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize( + "hidden_size_n_groups", + [ + (64, 1), + (64, 2), + (64, 4), # hidden_size be divisible by num_gpus + (100, 5), # and n_groups must divide hidden_size + ]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_mixer2_gated_norm_multi_gpu( + batch_size: int, + seq_len: int, + hidden_size_n_groups: Tuple[int, int], + dtype: torch.dtype, + device: str = 'cuda', +): + hidden_size, n_groups = hidden_size_n_groups + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn(fn, + args=( + num_processes, + batch_size, + seq_len, + hidden_size, + n_groups, + dtype, + device, + ), + nprocs=nprocs) + + run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2) + + +def mixer2_gated_norm_tensor_parallel( + local_rank: int, + world_size: int, + batch_size: int, + seq_len: int, + hidden_size: int, + n_groups: int, + dtype: torch.dtype, + device: str, +): + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # create random weights an inputs + weight = torch.rand((hidden_size, ), dtype=dtype, device=device) + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + gate_states = torch.randn(batch_size, seq_len, hidden_size) + + # create gated-norm with TP + mixer = Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + mixer.weight.weight_loader(mixer.weight, weight) # load + + # create gated-norm without TP to compute reference + # - utilize mock patching to disable TP when + with (unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." 
+ "get_tensor_model_parallel_world_size", + return_value=1), + unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." + "get_tensor_model_parallel_rank", + return_value=0)): + mixer_single_gpu = Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + # assign weight to single-gpu mixer + mixer_single_gpu.weight.data = weight + + # generate and compare + N = hidden_size // world_size + output = mixer( + hidden_states[..., local_rank * N:(local_rank + 1) * N], + gate_states[..., local_rank * N:(local_rank + 1) * N], + ) + ref_output = mixer_single_gpu(hidden_states, gate_states) + torch.allclose(output, + ref_output[..., local_rank * N:(local_rank + 1) * N], + atol=1e-3, + rtol=1e-3) diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py new file mode 100644 index 0000000000000..882513116ed6d --- /dev/null +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Tuple + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.platforms import current_platform + +# Added by the IBM Team, 2024 + +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py + + +# this is the segsum implementation taken from above +def segsum(x): + """Calculates segment sum.""" + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): + """ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) + for x in (X, A, B, C)) + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at + # chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms + # (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +def generate_random_inputs(batch_size, + seqlen, + n_heads, + d_head, + itype, + device='cuda'): + + current_platform.seed_everything(0) + A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + dt = F.softplus( + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - + 4) + X = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + + return A, dt, X, B, C + + +def generate_continous_batched_examples(example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device='cuda'): + + # this function generates a random examples of certain length + # and then cut according to "example_lens_by_batch" and feed + # them in continuous batches to the kernels + + # generate the full-length example + A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // 4) + + # internal function that outputs a cont batch of examples + # given a tuple of lengths for each example in the batch + # e.g., example_lens=(8, 4) means take 8 samples from first eg, + # 4 examples from second eg, etc + def get_continuous_batch(example_lens: Tuple[int, ...]): + + indices = [] + for i, x in enumerate(example_lens): + c = last_taken.get(i, 0) + indices.append((c, c + x)) + last_taken[i] = (c + x) % full_length + exhausted[i] = last_taken[i] == 0 + + return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) + ]).unsqueeze(0) for x in (dt, X, B, C)) + + # internal function that maps "n" to the appropriate right boundary + # value when forming continuous batches from examples of length given + # by "full_length". 
+ # - e.g., when n > full_length, returns n % full_length + # when n == full_length, returns full_length + def end_boundary(n: int): + return n - ((n - 1) // full_length) * full_length + + IND_E = None + for spec in example_lens_by_batch: + + # get the (maybe partial) example seen in this cont batch + dt2, X2, B2, C2 = get_continuous_batch(spec) + + # get the metadata + cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) + sed_idx = torch.zeros(cu_seqlens[-1], + dtype=torch.int32, + device=cu_seqlens.device) + for i, (srt, end) in enumerate(zip( + cu_seqlens, + cu_seqlens[1:], + )): + sed_idx[srt:end] = i + + # for cont batch + if IND_E is None: + IND_S = [0 for _ in range(len(spec))] + else: + IND_S = [x % full_length for x in IND_E] + IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] + + yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) +@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) +@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)]) +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, + itype): + + # this tests the kernels on a single example (no batching) + + # set seed + batch_size = 1 # batch_size + # ssd_minimal_discrete requires chunk_size divide seqlen + # - this is only required for generating the reference seqs, + # it is not an operational limitation. + seqlen, chunk_size = seq_len_chunk_size + + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, + B, C, chunk_size) + + Y, final_state = mamba_chunk_scan_combined(X, + dt, + A, + B, + C, + chunk_size, + D=None, + return_final_states=True) + + # just test the last in sequence + torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3) + + # just test the last head + # NOTE, in the kernel we always cast states to fp32 + torch.allclose(final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=1e-3, + rtol=1e-3) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("n_heads", [4, 8, 13]) +@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +@pytest.mark.parametrize( + "seq_len_chunk_size_cases", + [ + + # small-ish chunk_size (8) + (64, 8, 2, [(64, 32), (64, 32)]), + (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), + (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary + (64, 8, 2, [(4, 4), (4, 4), (4, 4), + (4, 4)]), # chunk_size larger than cont batches + (64, 8, 5, [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ]), # mode examples with varied lengths + + # odd chunk_size + (64, 29, 2, [(11, 4), (13, 23), (19, 22), + (21, 15)]), # irregular sizes + + # large-ish chunk_size (256) + (64, 256, 1, [(5, ), (1, ), (1, ), + (1, )]), # irregular sizes with small sequences + (64, 256, 2, [(5, 30), (1, 2), (1, 2), + (1, 2)]), # irregular sizes with small sequences + ]) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, + itype): + + # this test with multiple examples in a continuous batch + # (i.e. 
chunked prefill) + + seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: Dict = {} # map: eg -> pointer to last taken sample + exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted + + states = None + for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + C) in generate_continous_batched_examples( + cases, num_examples, seqlen, + last_taken, exhausted, n_heads, + d_head, itype): + + Y, new_states = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=sed_idx, + return_varlen_states=True, + initial_states=states, + ) + + # just test the last in sequence + for i in range(num_examples): + + # just test one dim and dstate + Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_min_eg = Y_min[i][:, 0, 0] + torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) + + # update states + states = new_states + for i, clear in exhausted.items(): + if clear: + states[i].fill_(0.) + exhausted[i] = False diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py new file mode 100644 index 0000000000000..032e20470bcd3 --- /dev/null +++ b/tests/lora/test_punica_ops.py @@ -0,0 +1,652 @@ +# SPDX-License-Identifier: Apache-2.0 +from threading import Lock +from typing import List + +import pytest +import torch + +import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.platforms import current_platform + +from .utils import (PunicaTensors, assert_close, generate_data, + generate_data_for_expand_nslices, + generate_data_for_nslices) + + +# Utility shrink and expand operations used as reference implementations. +def sgmv_shrink_for_nslices( + nslices: int, inputs_tensor: torch.Tensor, + lora_weights_lst: List[torch.Tensor], out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int, + num_tokens: int, scaling: float): + """ + Wrapper around sgmv_shrink that handles any nslices. + """ + for index in range(nslices): + sgmv_shrink( + inputs_tensor, + lora_weights_lst[index], + out_tensor[index], + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + scaling, + ) + + +def sgmv_expand_for_nslices(nslices: int, hidden_size: int, + inputs_tensor: torch.Tensor, + lora_weights_lst: List[torch.Tensor], + out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, batches: int, + max_seq_length: int, num_tokens: int, + add_inputs: bool) -> None: + """ + Wrapper around sgmv_expand that handles any nslices. 
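+    When nslices == 1 this reduces to a single torch sgmv_expand call;
+    otherwise the output is assembled slice by slice via sgmv_expand_slice,
+    advancing slice_offset by hidden_size for each slice.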
+ """ + if nslices == 1: + # Verify the torch's sgmv_expand op + sgmv_expand( + inputs_tensor[0], + lora_weights_lst[0], + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + add_inputs=add_inputs, + ) + else: + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + sgmv_expand_slice( + inputs_tensor[index], + lora_weights, + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + slice_offset, + hidden_size, + add_inputs=add_inputs, + ) + slice_offset += hidden_size + + +_dict_lock = Lock() + + +def check_sgmv_shrink(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, seq_length: int, scaling: float): + """ + Compare outputs of vllm.sgmv_shrink kernel against a reference + implementation. + """ + data: PunicaTensors = generate_data_for_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + nslices, + dtype, + "shrink", + device, + ) + max_seq_length, token_nums = data.meta() + + # Preventing cache error pointer. + with _dict_lock: + _LORA_A_PTR_DICT.clear() + torch.ops.vllm.sgmv_shrink( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + scaling, + ) + + sgmv_shrink_for_nslices( + nslices, + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + scaling, + ) + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_sgmv_expand(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, seq_length: int, add_inputs: bool): + """ + Compare outputs of vllm.sgmv_expand kernel against a reference + implementation. + """ + data: PunicaTensors = generate_data_for_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + nslices, + dtype, + "expand", + device, + ) + + max_seq_length, token_nums = data.meta() + + with _dict_lock: + _LORA_B_PTR_DICT.clear() + torch.ops.vllm.sgmv_expand( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + offset_start=0, + add_inputs=add_inputs, + ) + + sgmv_expand_for_nslices(nslices, + hidden_size, + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + add_inputs=add_inputs) + + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_bgmv_shrink(batches: int, num_loras: int, rank: int, + hidden_size: int, dtype: torch.dtype, device: str, + scaling: float): + """ + Compare vllm.bgmv_shrink against a reference implementation. 
+ """ + seq_length = 1 + data: PunicaTensors = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + "shrink", + device, + ) + + torch.ops.vllm.bgmv_shrink( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.token_lora_mapping, + scaling, + ) + + bgmv_shrink( + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.token_lora_mapping, + scaling, + ) + + data.ref_out_tensor = data.ref_out_tensor.to(torch.float32) + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_bgmv_expand(batches: int, num_loras: int, rank: int, + hidden_size: int, dtype: torch.dtype, device: str, + add_inputs: bool): + """ + Compare vllm.bgmv_expand against a reference implementation. + """ + seq_length = 1 + data: PunicaTensors = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + "expand", + device, + ) + + torch.ops.vllm.bgmv_expand( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.token_lora_mapping, + add_inputs=add_inputs, + ) + bgmv_expand( + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.token_lora_mapping, + add_inputs=add_inputs, + ) + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, add_inputs: bool): + """ + Compare vllm.bgmv_expand_slice against a reference implementation. + """ + seq_length = 1 + data: PunicaTensors = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + + slice_offset = 0 + for index in range(nslices): + torch.ops.vllm.bgmv_expand_slice( + data.inputs_tensor, + data.lora_weights[index], + data.our_out_tensor, + data.token_lora_mapping, + slice_offset, + slice_size=hidden_size, + add_inputs=add_inputs, + ) + bgmv_expand_slice( + data.inputs_tensor, + data.lora_weights[index], + data.ref_out_tensor, + data.token_lora_mapping, + slice_offset, + slice_size=hidden_size, + add_inputs=add_inputs, + ) + + slice_offset += hidden_size + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +# Tests +# We test the punica kernels along 2 verticals mainly. +# 1. Variations in hidden_dim size +# 2. Variations in all other parameters like (batch_size, max_rank, num_loras +# etc.) + +# We have collected the hidden_sizes included in the LoRA models +# currently supported by vLLM. It tests whether the corresponding Triton +# kernel can run normally when tensor parallelism is set to +# [1, 2, 4, 8, 16, 32, 64]. 
+HIDDEN_SIZES = [ + 128, + 256, + 512, + 896, + 1024, + 1152, + 1216, + 1280, + 1536, + 1664, + 2048, + 2240, + 2304, + 2368, + 2432, + 2560, + 2752, + 3072, + 3328, + 3456, + 3584, + 3712, + 4096, + 4480, + 4608, + 4736, + 4864, + 5120, + 5504, + 5632, + 5888, + 6144, + 6400, + 6848, + 6912, + 7168, + 7424, + 8192, + 8960, + 9216, + 9472, + 10240, + 11008, + 11264, + 13824, + 14336, + 14784, + 14848, + 15360, + 18944, + 22016, + 22528, + 24576, + 27392, + 27648, + 29568, + 29696, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 43264, + 49152, + 49408, + 60544, + 60672, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, +] +#The size of TP +divisibility = [1, 2, 8, 16, 64] + +all_hidden_size = [] +for div in divisibility: + for hidden_size in HIDDEN_SIZES: + all_hidden_size.append(hidden_size // div) + +HIDDEN_SIZES = list(set(all_hidden_size)) + +# Test params that focuses on hidden_size variation. +hs_test_params = { + "hidden_sizes": HIDDEN_SIZES, + "batches": [4], + "num_loras": [4], + "max_ranks": [32], +} + +# General tests params that tests for variations in all dimensions +# except hidden_size. +test_params = { + "hidden_sizes": [2049], + "batches": [1, 4, 16, 32], + "num_loras": [1, 8, 32, 128], + "max_ranks": [1, 4, 8, 16, 32, 64, 128, 256], +} + +DTYPES = [torch.float16, torch.bfloat16] +DEVICES = [f"cuda:{0}"] +SEED = [0] + + +@pytest.mark.parametrize("batches", test_params['batches']) +@pytest.mark.parametrize("num_loras", test_params['num_loras']) +@pytest.mark.parametrize("rank", test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [1, 2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_sgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) + else: + check_sgmv_expand(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) + + +@pytest.mark.parametrize("batches", hs_test_params['batches']) +@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) +@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [1, 2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_sgmv_hidden_size( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_sgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) + else: + check_sgmv_expand(batches=batches, 
+ num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) + + +@pytest.mark.parametrize("batches", test_params['batches']) +@pytest.mark.parametrize("num_loras", test_params['num_loras']) +@pytest.mark.parametrize("rank", test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_bgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + scaling=0.5) + else: + check_bgmv_expand(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + add_inputs=True) + + +@pytest.mark.parametrize("batches", hs_test_params['batches']) +@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) +@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_bgmv_hidden_size( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_bgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + scaling=0.5) + else: + check_bgmv_expand(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + add_inputs=True) + + +@pytest.mark.parametrize("batches", test_params['batches']) +@pytest.mark.parametrize("num_loras", test_params['num_loras']) +@pytest.mark.parametrize("rank", test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, + dtype: torch.dtype, device: str, + seed: int): + + torch.set_default_device(device) + current_platform.seed_everything(seed) + + check_bgmv_expand_slice(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + add_inputs=True) + + +@pytest.mark.parametrize("batches", hs_test_params['batches']) +@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) +@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def 
test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int, + rank: int, hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, seed: int): + + torch.set_default_device(device) + current_platform.seed_everything(seed) + + check_bgmv_expand_slice(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + add_inputs=True) diff --git a/tests/lora/test_punica_ops_sizes.py b/tests/lora/test_punica_ops_sizes.py deleted file mode 100644 index ecd3bc4978f39..0000000000000 --- a/tests/lora/test_punica_ops_sizes.py +++ /dev/null @@ -1,401 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This script is mainly used to tests various hidden_sizes. We have collected the -hidden_sizes included in the LoRA models currently supported by vLLM. It tests -whether the corresponding Triton kernel can run normally when tensor parallelism -is set to [1, 2, 4, 8, 16, 32, 64]. -""" -from threading import Lock - -import pytest -import torch - -import vllm.lora.ops.triton_ops # noqa: F401 -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform - -from .utils import (assert_close, generate_data, - generate_data_for_expand_nslices, - generate_data_for_nslices) - -HIDDEN_SIZES = [ - 128, - 256, - 512, - 896, - 1024, - 1152, - 1216, - 1280, - 1536, - 1664, - 2048, - 2240, - 2304, - 2368, - 2432, - 2560, - 2752, - 3072, - 3328, - 3456, - 3584, - 3712, - 4096, - 4480, - 4608, - 4736, - 4864, - 5120, - 5504, - 5632, - 5888, - 6144, - 6400, - 6848, - 6912, - 7168, - 7424, - 8192, - 8960, - 9216, - 9472, - 10240, - 11008, - 11264, - 13824, - 14336, - 14784, - 14848, - 15360, - 18944, - 22016, - 22528, - 24576, - 27392, - 27648, - 29568, - 29696, - 32000, - 32256, - 32512, - 32768, - 33024, - 36864, - 43264, - 49152, - 49408, - 60544, - 60672, - 64000, - 64256, - 102400, - 102656, - 128000, - 128256, -] -#The size of TP -divisibility = [1, 2, 8, 16, 64] - -all_hidden_size = [] -for div in divisibility: - for hidden_size in HIDDEN_SIZES: - all_hidden_size.append(hidden_size // div) - -HIDDEN_SIZES = list(set(all_hidden_size)) - -BATCHES = [4] -NUM_LORA = [4] -DTYPES = [torch.float16, torch.bfloat16] -MAX_RANKS = [32] -SCALES = [0.5] -SEED = [0] -DEVICES = [f"cuda:{0}"] - -_dict_lock = Lock() - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("nslices", [1, 2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_sgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - nslices: int, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 - ( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - nslices, - dtype, - op_type, - device, - 
) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() - if op_type == "shrink": - # Preventing cache error pointer. - with _dict_lock: - _LORA_A_PTR_DICT.clear() - torch.ops.vllm.sgmv_shrink( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - for index in range(nslices): - sgmv_shrink( - inputs_tensor, - lora_weights_lst[index], - ref_out_tensor[index], - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - - else: - with _dict_lock: - _LORA_B_PTR_DICT.clear() - torch.ops.vllm.sgmv_expand( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - offset_start=0, - add_inputs=True, - ) - if nslices == 1: - # Verify the torch's sgmv_expand op - sgmv_expand( - inputs_tensor[0], - lora_weights_lst[0], - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - else: - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - sgmv_expand_slice( - inputs_tensor[index], - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - slice_offset += hidden_size - - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - if op_type == "shrink": - torch.ops.vllm.bgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - scaling, - ) - - bgmv_shrink( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - scaling, - ) - - else: - torch.ops.vllm.bgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - add_inputs=True, - ) - bgmv_expand( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - add_inputs=True, - ) - - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) 
-@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv_expand_nslices( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - nslices: int, - dtype: torch.dtype, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights_lst, - our_outputs, - ref_outputs, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_expand_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - nslices, - device, - ) - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - torch.ops.vllm.bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - - slice_offset += hidden_size - assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py deleted file mode 100644 index 6d1d3c9430f38..0000000000000 --- a/tests/lora/test_punica_ops_variation.py +++ /dev/null @@ -1,317 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This script is mainly used to test whether trtion kernels can run normally -under different conditions, including various batches, numbers of LoRA , and -maximum ranks. -""" -from threading import Lock - -import pytest -import torch - -# Enable custom op register -import vllm.lora.ops.triton_ops # noqa: F401 -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform - -from .utils import (assert_close, generate_data, - generate_data_for_expand_nslices, - generate_data_for_nslices) - -HIDDEN_SIZES = [2049] - -BATCHES = [1, 4, 16, 32] -NUM_LORA = [1, 8, 32, 128] -DTYPES = [torch.float16, torch.bfloat16] -MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] -SCALES = [0.5] -SEED = [0] -DEVICES = [f"cuda:{0}"] - -_dict_lock = Lock() - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("nslices", [1, 2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_sgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - nslices: int, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 - ( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - nslices, - dtype, - op_type, - device, - ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length 
= max_seq_length.item() - if op_type == "shrink": - # Preventing cache error pointer. - with _dict_lock: - _LORA_A_PTR_DICT.clear() - torch.ops.vllm.sgmv_shrink( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - for index in range(nslices): - sgmv_shrink( - inputs_tensor, - lora_weights_lst[index], - ref_out_tensor[index], - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - - else: - with _dict_lock: - _LORA_B_PTR_DICT.clear() - torch.ops.vllm.sgmv_expand( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - offset_start=0, - add_inputs=True, - ) - slice_offset = 0 - if nslices == 1: - # Verify the torch's sgmv_expand op - sgmv_expand( - inputs_tensor[0], - lora_weights_lst[0], - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - else: - for index in range(nslices): - lora_weights = lora_weights_lst[index] - sgmv_expand_slice( - inputs_tensor[index], - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - slice_offset += hidden_size - - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - if op_type == "shrink": - torch.ops.vllm.bgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - scaling, - ) - - bgmv_shrink( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - scaling, - ) - - else: - torch.ops.vllm.bgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - add_inputs=True, - ) - bgmv_expand( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - add_inputs=True, - ) - - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv_expand_nslices( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - 
nslices: int, - dtype: torch.dtype, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights_lst, - our_outputs, - ref_outputs, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_expand_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - nslices, - device, - ) - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - torch.ops.vllm.bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - - slice_offset += hidden_size - assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index bda00e08190ef..1e163fbf97ce3 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union import torch @@ -106,6 +107,31 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) +@dataclass +class PunicaTensors: + inputs_tensor: torch.Tensor + lora_weights: Union[torch.Tensor, List[torch.Tensor]] + our_out_tensor: torch.Tensor + ref_out_tensor: torch.Tensor + b_seq_start_loc: torch.Tensor + prompt_lora_mapping: torch.Tensor + seq_len_tensor: torch.Tensor + token_lora_mapping: torch.Tensor + + def meta(self) -> Tuple[int, int]: + """ + Infer max_seq_length and token_nums from the tensors + and return them. 
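+        For example, 4 sequences of length 128 each give
+        max_seq_length == 128 and token_nums == 512.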
+ """ + max_seq_length = self.seq_len_tensor.max() + token_nums = self.seq_len_tensor.sum().item() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + return max_seq_length, token_nums + + def generate_data( batches, hidden_size, @@ -115,7 +141,7 @@ def generate_data( dtype, op_type, device, -): +) -> PunicaTensors: seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -164,7 +190,8 @@ def generate_data( indices[current_offset:current_offset + seq_len_tensor[b_id]].copy_(lora_index) current_offset += seq_len_tensor[b_id].item() - return ( + + return PunicaTensors( inputs_tensor, lora_weights, our_out_tensor, @@ -185,7 +212,7 @@ def generate_data_for_expand_nslices( dtype, nslices, device, -): +) -> PunicaTensors: seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -222,7 +249,7 @@ def generate_data_for_expand_nslices( current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) - return ( + return PunicaTensors( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -244,7 +271,7 @@ def generate_data_for_nslices( dtype, op_type, device, -): +) -> PunicaTensors: seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -302,7 +329,7 @@ def generate_data_for_nslices( current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) - return ( + return PunicaTensors( inputs_tensor, lora_weights_lst, our_out_tensor, diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index fe9361d126120..d1f643a8fdb73 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_3" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" AudioTuple = Tuple[np.ndarray, int] diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_hybrid.py similarity index 91% rename from tests/models/decoder_only/language/test_jamba.py rename to tests/models/decoder_only/language/test_hybrid.py index cc98f1d7b5ce8..a39b11923582c 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -8,7 +8,8 @@ from ...utils import check_outputs_equal -MODELS = ["ai21labs/Jamba-tiny-dev"] +# This test is for the hybrid models +MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] @pytest.mark.parametrize("model", MODELS) @@ -23,6 +24,10 @@ def test_models( max_tokens: int, ) -> None: + # numeric error produces different generation + if 'Bamba' in model: + example_prompts.pop(3) + with hf_runner( model, dtype=dtype, @@ -108,15 +113,21 @@ def test_mamba_prefill_chunking_with_parallel_sampling( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [10]) +@pytest.mark.parametrize("max_tokens", [7]) def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int) -> None: # numeric error during prefill chucking produces different generation # compared to w/o prefill chunking for those 
examples, removed them for now - example_prompts.pop(7) - example_prompts.pop(2) - example_prompts.pop(1) + if 'Jamba' in model: + example_prompts.pop(7) + example_prompts.pop(2) + example_prompts.pop(1) + elif 'Bamba' in model: + example_prompts.pop(6) + example_prompts.pop(3) + example_prompts.pop(2) + dtype = "half" # use a different dtype for Bamba with hf_runner( model, @@ -145,7 +156,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [15]) def test_parallel_sampling( vllm_runner, @@ -249,17 +260,17 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba inner state management doesn't + # This test is for verifying that the hybrid inner state management doesn't # collapse in case where the number of incoming requests and # finished_requests_ids is larger than the maximum mamba block capacity. - # This could generally happen due to the fact that Jamba does support + # This could generally happen due to the fact that hybrid does support # statelessness mechanism where it can cleanup new incoming requests in # a single step. try: with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up properly between" + pytest.fail("Hybrid inner state wasn't cleaned up properly between" "steps finished requests registered unnecessarily ") @@ -271,14 +282,14 @@ def test_state_cleanup( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba state is cleaned up between + # This test is for verifying that the Hybrid state is cleaned up between # steps, If its not cleaned, an error would be expected. 
try: with vllm_runner(model, dtype=dtype) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up between states, " + pytest.fail("Hybrid inner state wasn't cleaned up between states, " "could be related to finished_requests_ids") @@ -324,7 +335,7 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) -def test_jamba_distributed_produces_identical_generation( +def test_hybrid_distributed_produces_identical_generation( vllm_runner, model: str, dtype: str, max_tokens: int, example_prompts) -> None: diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 1ad56241535b8..c6d5244318a32 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -26,6 +26,9 @@ "google/gemma-1.1-2b-it", # gemma marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + pytest.param( + "THUDM/chatglm3-6b", # ChatGLM (text-only) + ), pytest.param( "meta-llama/Llama-3.2-1B-Instruct", # llama marks=[pytest.mark.core_model, pytest.mark.cpu_model], @@ -43,6 +46,9 @@ "microsoft/phi-2", # phi marks=[pytest.mark.core_model], ), + pytest.param( + "Qwen/Qwen-7B", # qwen (text-only) + ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 marks=[pytest.mark.core_model], @@ -68,6 +74,10 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: + if model.startswith("THUDM/chatglm3"): + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.transformer.output_layer + hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 77cf3442df905..6244056c7474a 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -89,7 +89,7 @@ def _test_processing_correctness( mm_data = { k: [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) - for _ in range(rng.randint(limit))] + for _ in range(rng.randint(limit + 1))] for k, limit in limit_mm_per_prompt.items() } @@ -147,6 +147,7 @@ def _test_processing_correctness( "facebook/chameleon-7b", "deepseek-ai/deepseek-vl2-tiny", "adept/fuyu-8b", + "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", "HuggingFaceM4/Idefics3-8B-Llama3", @@ -163,7 +164,7 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 20787fe008aa8..66b7d3c2e77b5 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -102,6 +102,7 @@ def check_available_online( trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"), "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), # ChatGLMModel supports multimodal "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", @@ -266,7 +267,7 @@ def 
check_available_online( "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 min_transformers_version="4.49"), # noqa: E501 - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3", + "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", trust_remote_code=True), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py index 9a336b7e60ffc..40fcfeeeac7d0 100644 --- a/tests/multimodal/utils.py +++ b/tests/multimodal/utils.py @@ -17,10 +17,7 @@ def random_video( min_wh: int, max_wh: int, ): - # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 num_frames = rng.randint(min_frames, max_frames) - num_frames = (num_frames // 2) * 2 - w, h = rng.randint(min_wh, max_wh, size=(2, )) return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 5616935ebdc0c..3a7f0a196b5b8 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -55,10 +55,21 @@ def check_model(model): assert isinstance(attn.quant_method, Fp8KVCacheMethod) - # NOTE: it is valid for scales to be 1.0 (default value), but - # we know these checkpoints have scales < 1.0 - assert 0.0 < attn._k_scale < 1.0 - assert 0.0 < attn._v_scale < 1.0 + if not current_platform.is_rocm(): + # NOTE: This code path requires validation on Non-CUDA platform + # NOTE: it is valid for scales to be 1.0 (default value), but + # we know these checkpoints have scales < 1.0 + assert 0.0 < attn._k_scale < 1.0 + assert 0.0 < attn._v_scale < 1.0 + else: + # NOTE: This code path is for ROCm platform + # NOTE: it is valid for scales to be 1.0 (default value), but + # we know these checkpoints have scales < 1.0 + # However on ROCm platform, the _k_scale and _v_scale will be + # scaled by a factor of 2 as described in + # vllm/model_executor/layers/quantization/kv_cache.py + assert 0.0 < attn._k_scale < (1.0 * 2.0) + assert 0.0 < attn._v_scale < (1.0 * 2.0) llm.apply_model(check_model) @@ -91,13 +102,29 @@ def check_model(model): assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - if current_platform.has_device_capability(89) and not force_marlin: - # For GPUs with hardware support, we keep weights in fp8 - assert fc1.weight.dtype == torch.float8_e4m3fn - else: - # For GPUs without hardware support, we pack the fp8 weights - # for weight-only quantization using Marlin kernels - assert fc1.weight.dtype == torch.int32 + if current_platform.is_cuda(): + if current_platform.has_device_capability( + 89) and not force_marlin: + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fn + else: + # For GPUs without hardware support, we pack the fp8 weights + # for weight-only quantization using Marlin kernels + assert fc1.weight.dtype == torch.int32 + elif current_platform.is_rocm(): + # Only MI300 and above support quantization='fp8' + if current_platform.has_device_capability( + 94) and not force_marlin: + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fnuz + else: # unsupported ROCm platform + pytest.skip( + "Skip `test_load_fp16_model`. " + "It only runs on ROCm platform with FP8 compute." + " e.g. 
MI300X and above.") + else: # unsupported platform + pytest.skip("Skip `test_load_fp16_model`. " + "It only runs on CUDA and ROCm platform.") llm.apply_model(check_model) diff --git a/tests/quantization/test_ptpc_fp8.py b/tests/quantization/test_ptpc_fp8.py new file mode 100644 index 0000000000000..9bbb5e327968f --- /dev/null +++ b/tests/quantization/test_ptpc_fp8.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests whether PTPC w8a8 FP8 computation is enabled correctly. + +Run `pytest tests/quantization/test_ptpc_fp8.py --forked`. +""" +import pytest +import torch + +from tests.quantization.utils import is_quant_method_supported +from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod +from vllm.model_executor.layers.quantization.ptpc_fp8 import ( + PTPCFp8LinearMethod) +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"), + reason="PTPC FP8 is not supported on this GPU type.") +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="This test is for ROCm GPU.") +@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"]) +def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: + + try: + with vllm_runner("facebook/opt-125m", + dtype=dtype, + quantization="ptpc_fp8", + kv_cache_dtype=kv_cache_dtype) as llm: + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + fc1 = model.model.decoder.layers[0].fc1 + assert isinstance(fc1.quant_method, PTPCFp8LinearMethod) + if kv_cache_dtype == "ptpc_fp8": + attn = model.model.decoder.layers[0].self_attn.attn + assert isinstance(attn.quant_method, Fp8KVCacheMethod) + assert attn._k_scale == 1.0 + assert attn._v_scale == 1.0 + + if current_platform.has_device_capability(94): + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fnuz + else: + pytest.skip() + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + except AssertionError as e: + if str( + e + ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. 
torch.float16 is specified.": # noqa: E501 + # If the error message matches, the test passes + pass + else: + # If the error message does not match, re-raise the exception + raise diff --git a/tests/test_seed_behavior.py b/tests/test_seed_behavior.py new file mode 100644 index 0000000000000..7e4e71563e7d3 --- /dev/null +++ b/tests/test_seed_behavior.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +import random + +import numpy as np +import torch + +from vllm.platforms.interface import Platform + + +def test_seed_behavior(): + # Test with seed=None + Platform.seed_everything(None) + random_value_1 = random.randint(0, 100) + np_random_value_1 = np.random.randint(0, 100) + torch_random_value_1 = torch.randint(0, 100, (1, )).item() + + Platform.seed_everything(None) + random_value_2 = random.randint(0, 100) + np_random_value_2 = np.random.randint(0, 100) + torch_random_value_2 = torch.randint(0, 100, (1, )).item() + + assert random_value_1 != random_value_2 + assert np_random_value_1 != np_random_value_2 + assert torch_random_value_1 != torch_random_value_2 + + # Test with a specific seed + Platform.seed_everything(42) + random_value_3 = random.randint(0, 100) + np_random_value_3 = np.random.randint(0, 100) + torch_random_value_3 = torch.randint(0, 100, (1, )).item() + + Platform.seed_everything(42) + random_value_4 = random.randint(0, 100) + np_random_value_4 = np.random.randint(0, 100) + torch_random_value_4 = torch.randint(0, 100, (1, )).item() + + assert random_value_3 == random_value_4 + assert np_random_value_3 == np_random_value_4 + assert torch_random_value_3 == torch_random_value_4 diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py new file mode 100644 index 0000000000000..03e1f1fadd731 --- /dev/null +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from mistral_common.protocol.instruct.messages import UserMessage +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.protocol.instruct.tool_calls import Function, Tool + +from vllm.transformers_utils.tokenizers.mistral import ( + make_mistral_chat_completion_request) + + +# yapf: enable +@pytest.mark.parametrize( + "openai_request,expected_mistral_request", + [( + { + "messages": [{ + "role": "user", + "content": "What is the current local date and time?", + }], + "tools": [{ + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + }, + }], + }, + ChatCompletionRequest( + messages=[ + UserMessage(content="What is the current local date and time?") + ], + tools=[ + Tool( + type="function", + function=Function( + name="get_current_time", + description="Fetch the current local date and time.", + parameters={}, + ), + ) + ], + ), + )], +) +def test_make_mistral_chat_completion_request(openai_request, + expected_mistral_request): + assert (make_mistral_chat_completion_request( + openai_request["messages"], + openai_request["tools"]) == expected_mistral_request) diff --git a/tests/utils.py b/tests/utils.py index 3b32052fe4c87..f39cbe7ede030 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -46,8 +46,9 @@ def _nvml(): finally: amdsmi_shut_down() elif current_platform.is_cuda(): - from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, - nvmlInit, nvmlShutdown) + from vllm.third_party.pynvml import (nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, nvmlInit, + 
nvmlShutdown) @contextmanager def _nvml(): diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 8df4cbe1be71b..ba08b83ec54e5 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -5,10 +5,11 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + KVCacheBlock, PrefixCachingMetrics, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) +from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -277,3 +278,39 @@ def test_hash_request_tokens_no_mm_inputs(): assert block_hashes[0].extra_keys is None assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1].extra_keys is None + + +def test_metrics(): + """ + Test the prefix caching metrics. + """ + + def stats(requests, queries, hits): + return PrefixCacheStats(requests=requests, queries=queries, hits=hits) + + metrics = PrefixCachingMetrics(interval=5) + assert metrics.hit_rate == 0.0 + + metrics.observe(stats(1, 20, 9)) + # 9 / 20 = 0.45 + assert metrics.hit_rate == 0.45 + + metrics.observe(stats(4, 80, 16)) + + # 25 / 100 = 0.25 + assert metrics.hit_rate == 0.25 + + metrics.observe(stats(1, 10, 2)) + + # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2 + assert metrics.aggregated_requests == 5 + assert metrics.aggregated_query_total == 90 + assert metrics.aggregated_query_hit == 18 + assert metrics.hit_rate == 0.2 + + metrics.reset() + assert metrics.hit_rate == 0.0 + assert metrics.aggregated_requests == 0 + assert metrics.aggregated_query_total == 0 + assert metrics.aggregated_query_hit == 0 + assert not metrics.query_queue diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index a6c0162d3f308..d598d12571f12 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -51,7 +51,7 @@ def test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(req0.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) @@ -76,7 +76,7 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -107,7 +107,7 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(req2.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -494,10 +494,11 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. 
assert not computed_blocks assert num_computed_tokens == 0 - assert len(req0.kv_block_hashes) == 3 - assert req0.kv_block_hashes[0].extra_keys == ("aaa", ) - assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb") - assert req0.kv_block_hashes[2].extra_keys == ("bbb", ) + block_hashes = manager.req_to_block_hashes[req0.request_id] + assert len(block_hashes) == 3 + assert block_hashes[0].extra_keys == ("aaa", ) + assert block_hashes[1].extra_keys == ("aaa", "bbb") + assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks) assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] @@ -510,8 +511,8 @@ def test_mm_prefix_caching(): assert new_blocks is not None and len(new_blocks) == 0 # The just completed block should have hashes with extra keys. - assert len(req0.kv_block_hashes) == 4 - assert req0.kv_block_hashes[3].extra_keys == ("ccc", ) + assert len(block_hashes) == 4 + assert block_hashes[3].extra_keys == ("ccc", ) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 @@ -613,7 +614,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) assert [b.block_id for b in blocks] == [4] diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 8eb08f3e842ca..0d29729a454cf 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -195,8 +195,8 @@ def test_schedule_partial_requests(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[0] * len(requests), - logprob_token_ids_cpu=None, - logprobs_cpu=None, + logprobs=None, + prompt_logprobs_dict={}, ) scheduler.update_from_output(output, model_runner_output) diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py new file mode 100644 index 0000000000000..560dc31218522 --- /dev/null +++ b/tests/v1/engine/conftest.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Tuple + +import pytest +import torch +from transformers import AutoTokenizer + +from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN, + TOKENIZER_NAME, + DummyOutputProcessorTestVectors, + generate_dummy_prompt_logprobs_tensors, + generate_dummy_sample_logprobs) +from vllm.engine.arg_utils import EngineArgs +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +from tests.v1.engine.utils import FULL_STRINGS # isort: skip + +EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]] +EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor] + + +def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: + """Generate output processor dummy test vectors, without logprobs + + Returns: + DummyOutputProcessorTestVectors instance with no logprobs + """ + + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) + vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config() + # Tokenize prompts under test & create dummy generated tokens + prompt_tokens = [ + tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS + ] + generation_tokens = [ + tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS + ] + # Generate prompt strings 
+ prompt_strings = [ + tokenizer.decode(prompt_tokens, skip_special_tokens=True) + for prompt_tokens in prompt_tokens + ] + prompt_strings_len = [ + len(prompt_string) for prompt_string in prompt_strings + ] + return DummyOutputProcessorTestVectors( + tokenizer=tokenizer, + tokenizer_group=init_tokenizer_from_configs( + vllm_config.model_config, vllm_config.scheduler_config, + vllm_config.parallel_config, vllm_config.lora_config), + vllm_config=vllm_config, + full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], + prompt_tokens=prompt_tokens, + generation_tokens=generation_tokens, + prompt_strings=prompt_strings, + prompt_strings_len=prompt_strings_len, + generation_strings=[ + text[prompt_len:] + for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len) + ], + prompt_logprobs=[], + generation_logprobs=[]) + + +@pytest.fixture +def dummy_test_vectors() -> DummyOutputProcessorTestVectors: + """Generate output processor dummy test vectors, with logprobs + + Returns: + DummyOutputProcessorTestVectors instance with logprobs + """ + # Build dummy test vectors without logprobs + dtv = _build_test_vectors_no_logprobs() + # Inject logprobs into dummy test vectors + # data structure + dtv.generation_logprobs = [ + generate_dummy_sample_logprobs( + sampled_tokens_list=tokens_list, + num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST, + tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens + ] + dtv.prompt_logprobs = [ + generate_dummy_prompt_logprobs_tensors( + prompt_tokens_list=tokens_list, + num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST, + tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens + ] + return dtv diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 4b5bc9ced3733..94e18289e3c7f 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -2,10 +2,11 @@ import asyncio from contextlib import ExitStack -from typing import List, Tuple +from typing import List, Optional, Tuple import pytest +from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.platforms import current_platform @@ -21,13 +22,19 @@ disable_log_requests=True) -async def generate(engine: AsyncLLM, request_id: str, +async def generate(engine: AsyncLLM, + request_id: str, output_kind: RequestOutputKind, - max_tokens: int) -> Tuple[int, str]: + max_tokens: int, + prompt_logprobs: Optional[int] = None) -> Tuple[int, str]: + # Ensure generate doesn't complete too fast for cancellation test. 
+ await asyncio.sleep(0.2) + count = 0 sampling_params = SamplingParams(max_tokens=max_tokens, output_kind=output_kind, - temperature=0) + temperature=0, + prompt_logprobs=prompt_logprobs) async for out in engine.generate(request_id=request_id, prompt="Hello my name is Robert and", sampling_params=sampling_params): @@ -43,6 +50,40 @@ async def generate(engine: AsyncLLM, request_id: str, return count, request_id +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_async_llm_refuses_prompt_logprobs_with_apc( + monkeypatch, output_kind: RequestOutputKind): + """Test passes if AsyncLLM raises an exception when it is configured + for automatic prefix caching and it receives a request with + prompt_logprobs enabled, which is incompatible.""" + # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a + # better way to test V1 so that in the future when we switch, we don't + # have to change all the tests. + monkeypatch.setenv("VLLM_USE_V1", "1") + # Create AsyncLLM engine with APC + apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m", + enable_prefix_caching=True, + gpu_memory_utilization=0.8, + disable_log_requests=True) + engine = AsyncLLM.from_engine_args(apc_engine_args) + try: + with pytest.raises(ValueError) as excinfo: + # Issue a request with prompt logprobs enabled, which should fail + await asyncio.create_task( + generate(engine, + "request-0", + output_kind, + 10, + prompt_logprobs=5)) + # Validate exception string is correct + assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG + finally: + # Shut down engine + engine.shutdown() + + @pytest.mark.parametrize( "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) @pytest.mark.asyncio diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py new file mode 100644 index 0000000000000..84b634316cb46 --- /dev/null +++ b/tests/v1/engine/test_llm_engine.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG +from vllm import LLM, SamplingParams + + +def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): + """Test passes if LLMEngine raises an exception when it is configured + for automatic prefix caching and it receives a request with + prompt_logprobs enabled, which is incompatible.""" + + monkeypatch.setenv("VLLM_USE_V1", "1") + # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. 
+ monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + with pytest.raises(ValueError) as excinfo: + LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate( + "Hello, my name is", + SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)) + + # Validate exception string is correct + assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 5782a249f3627..c8f43edb70b3a 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -1,82 +1,47 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List +import math +from typing import Dict, List, Optional import pytest -from transformers import AutoTokenizer -from vllm.engine.arg_utils import EngineArgs +from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, + STOP_STRINGS, + DummyOutputProcessorTestVectors, + MockEngineCore) from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.sequence import PromptLogprobs, SampleLogprobs +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor -TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" -VLLM_CONFIG = EngineArgs(model=TOKENIZER_NAME).create_engine_config() -TOKENIZER_GROUP = init_tokenizer_from_configs(VLLM_CONFIG.model_config, - VLLM_CONFIG.scheduler_config, - VLLM_CONFIG.parallel_config, - VLLM_CONFIG.lora_config) -tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) - -FULL_STRINGS = [ - "My name is Robert from Neural Magic and I love working on vLLM so much!", - "Red Hat is the best open source company by far across Linux, K8s, and AI.", - "Nick is the name of my brother in addition to my colleague from Red Hat.", -] - -STOP_STRINGS = ["I love working on", "company by far", "brother in"] - -FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS] -PROMPT_LEN = 5 -PROMPT_TOKENS = [ - tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS -] -GENERATION_TOKENS = [ - tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS -] -PROMPT_STRINGS = [ - tokenizer.decode(prompt_tokens, skip_special_tokens=True) - for prompt_tokens in PROMPT_TOKENS -] -PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS] -GENERATION_STRINGS = [ - text[prompt_len:] - for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN) -] - - -class MockEngineCore: - """Mock outputs form premade tokens lists.""" - - def __init__(self, tokens_list: List[List[int]]): - self.tokens_list = tokens_list - self.current_idx = 0 - - def get_outputs(self) -> List[EngineCoreOutput]: - token_idx = self.current_idx - self.current_idx += 1 - - outputs = [] - for req_idx, token_ids in enumerate(self.tokens_list): - if len(token_ids) > token_idx: - output = EngineCoreOutput(request_id=f"request-{req_idx}", - new_token_ids=[token_ids[token_idx]], - finished=False) - if token_idx == len(token_ids) - 1: - output.finished = True - output.finish_reason = "stopped" - outputs.append(output) - - return outputs + +def _ref_convert_id_to_token( + tokenizer: AnyTokenizer, + token_id: int, +) -> str: + """Reference impl of logprobs detokenization. 
+ + Args: + tokenizer: tokenizer used by the model under test + token_id: convert this token id + + Returns: + String representation of input token id + """ + return tokenizer.convert_ids_to_tokens(token_id) or "" @pytest.mark.parametrize( "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -def test_incremental_detokenization(request_output_kind: RequestOutputKind): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) - engine_core = MockEngineCore(GENERATION_TOKENS) +def test_incremental_detokenization(request_output_kind: RequestOutputKind, + dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens) # Make N requests. requests = [ @@ -94,10 +59,10 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): spaces_between_special_tokens=False, output_kind=request_output_kind, stop=[], - include_stop_str_in_output=False)) - for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + include_stop_str_in_output=False, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add requests to the detokenizer. @@ -113,7 +78,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): break # Step the Detokenizer. - processed_outputs = output_processor.process_outputs(outputs, ) + processed_outputs = output_processor.process_outputs(outputs) request_outputs = processed_outputs.request_outputs requests_to_abort = processed_outputs.reqs_to_abort assert len(requests_to_abort) == 0 @@ -132,7 +97,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): # Confirmed tracked values matches what we expected. 
for idx, (ref_gen_str, ref_gen_toks) in enumerate( - zip(GENERATION_STRINGS, GENERATION_TOKENS)): + zip(dummy_test_vectors.generation_strings, + dummy_test_vectors.generation_tokens)): gen_str = gen_strings[f"request-{idx}"] gen_toks = gen_tokens[f"request-{idx}"] @@ -143,15 +109,390 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): assert not output_processor.has_unfinished_requests() +def _validate_logprobs( + gen_tokens: Dict[str, List[int]], + gen_logprobs: Dict[str, Optional[SampleLogprobs]], + gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]], + gen_cumulative_logprob: Dict[str, float], + dtv: DummyOutputProcessorTestVectors, + request_id_list: List[str], + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], +) -> None: + for req_idx, req_id in enumerate(request_id_list): + new_tokens = gen_tokens[req_id] + logprobs = gen_logprobs[req_id] + prompt_logprobs = gen_prompt_logprobs[req_id] + cumulative_logprob = gen_cumulative_logprob[req_id] + prompt_token_ids = dtv.prompt_tokens[req_idx] + ref_logprobs = dtv.generation_logprobs[req_idx] + ref_prompt_logprobs = dtv.prompt_logprobs[req_idx] + if num_sample_logprobs is not None: + # Validate sample logprobs + assert logprobs is not None, (f"Request {req_id} requires sample" + " logprobs but sample logprobs are" + " None.") + # Require num sampled tokens to match num + # sampled logprobs - especially important + # to check since the detokenizer can cause + # a request to finish early due to a stop + # string being hit + num_new_tokens = len(new_tokens) + len_sample_logprobs = len(logprobs) + assert num_new_tokens == len_sample_logprobs, ( + f"Request {req_id} has {num_new_tokens}" + " completion tokens but has" + f" {len_sample_logprobs} sample logprobs.") + ref_cumulative_logprob = 0.0 + for idx, (sampled_token, + pos_logprob_dict) in enumerate(zip(new_tokens, + logprobs)): + # Break out the reference log probability value & + # logprob token id tensors associated with this + # position in the completion. Also break out the + # sampled token ranks + (ref_pos_logprob_toks, ref_pos_logprob_vals, + ref_sampled_token_rank) = ref_logprobs[idx] + # For each position in the completion sequence, + # ensure the actual sampled token is among the + # logprobs + assert sampled_token in pos_logprob_dict, ( + f"Sampled token {sampled_token} not" + f" present in logprob at index {idx}") + + # Validate number of sample logprobs + num_lp_toks = len(pos_logprob_dict) + assert (num_lp_toks == num_sample_logprobs + or num_lp_toks == num_sample_logprobs + + 1), ("Valid numbers of sample logprobs are" + f" {num_sample_logprobs} or" + f" {num_sample_logprobs+1} but" + f" {num_lp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}") + + # Validate sampled token logprob rank + smp_lp = pos_logprob_dict[sampled_token] + smp_lp_rank = smp_lp.rank + assert (ref_sampled_token_rank == smp_lp_rank), ( + "Sampled token logprob rank" + f" {smp_lp_rank} does not match" + " correct value" + f" {ref_sampled_token_rank}" + f" in Logprob {smp_lp}") + + # Validate that the logprob processor yields + # the correct log probabilities and valid + # rankings + rank_one_appears = False + for jdx in range(1, len(ref_pos_logprob_toks)): + # Iterate over the (logprob val,logprob tok id) + # pairs expected by the test fixture at this + # position in the completion. 
+ ref_lp_val = ref_pos_logprob_vals[jdx] + ref_tok_id = ref_pos_logprob_toks[jdx] + assert ref_tok_id in pos_logprob_dict, ( + f"Expected token {ref_tok_id} to be" + f" in logprob dict but it is not.") + + # Extract actually-generated logprob + # info + lp = pos_logprob_dict[ref_tok_id] + lp_val = lp.logprob + lp_rank = lp.rank + + # A "top" (rank 1) logprob must be + # present + rank_one_appears = (True + if lp_rank == 1 else rank_one_appears) + + # Rank must be >= 1 + assert lp_rank >= 1, (f"Logprob {lp} has invalid" + f" rank {lp_rank} < 1." + f" Logprob dict: {pos_logprob_dict}") + + # Validate log probability + assert math.isclose(lp_val, ref_lp_val), ( + f"Token id {ref_tok_id} appears in logprobs dict" + f" at position {idx} in completion with log" + f" probability {lp_val} but {ref_lp_val} was" + f" expected. Logprob: {lp}") + + assert rank_one_appears, (f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}") + + # Validate logprobs detokenization + for lp_tok in pos_logprob_dict: + # Confirm that sample logprob decoded token matches + # the logprob token id at this sequence position + decoded_token = pos_logprob_dict[lp_tok].decoded_token + ref_decoded_token = _ref_convert_id_to_token( + dtv.tokenizer, lp_tok) + assert decoded_token == ref_decoded_token, ( + f"Sampled logprob token id {lp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})") + + ref_cumulative_logprob += pos_logprob_dict[ + sampled_token].logprob + # Assert that cumulative logprobs are correct + assert math.isclose(cumulative_logprob, ref_cumulative_logprob) + else: + # Sample logprobs disabled for this request + assert logprobs is None + assert cumulative_logprob is None + + if num_prompt_logprobs is not None: + # Validate prompt logprobs + assert prompt_logprobs is not None, ( + f"Request {req_id} requires prompt" + " logprobs but prompt logprobs are" + " None.") + # Require num prompt tokens to match num + # prompt logprobs + num_prompt_tokens = len(prompt_token_ids) + len_prompt_logprobs = len(prompt_logprobs) + assert num_prompt_tokens == len_prompt_logprobs, ( + f"Request {req_id} has {num_prompt_tokens}" + " prompt tokens but has" + f" {len_prompt_logprobs} prompt logprobs.") + # First prompt logprob is None + first_plp_dict = prompt_logprobs[0] + assert first_plp_dict is None, ( + f"Request {req_id} first prompt logprob" + f" should be None but has following value" + f" instead: {first_plp_dict}") + # Break out the reference prompt log prob value & + # logprob token id matrices for the whole prompt. + # Also break out the prompt token rank vector + (ref_prompt_logprob_toks, ref_prompt_logprob_vals, + ref_prompt_token_ranks) = ref_prompt_logprobs + for idx, (prompt_token, pos_logprob_dict) in enumerate( + zip(prompt_token_ids[1:], prompt_logprobs[1:])): + + # Break out the reference prompt log prob value + # vector, prompt logprob token id vector, and + # prompt token rank at the current position. 
+ (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals, + ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :], + ref_prompt_logprob_vals[idx, :], + ref_prompt_token_ranks[idx]) + + # For each position in the prompt sequence, + # ensure the actual prompt token is among the + # logprobs + assert prompt_token in pos_logprob_dict, ( + f"Prompt token {prompt_token} not" + f" present in logprob at index {idx}") + # Validate number of prompt logprobs + num_plp_toks = len(pos_logprob_dict) + assert (num_plp_toks == num_prompt_logprobs + or num_plp_toks == num_prompt_logprobs + + 1), ("Valid numbers of prompt logprobs are" + f" {num_prompt_logprobs} or" + f" {num_prompt_logprobs+1} but" + f" {num_plp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}") + + # Validate prompt token logprob rank + prmpt_tok_lp = pos_logprob_dict[prompt_token] + prmpt_tok_lp_rank = prmpt_tok_lp.rank + ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank + assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), ( + "Prompt token logprob rank" + f" {prmpt_tok_lp_rank} does not match" + " correct value" + f" {ref_prmpt_tok_lp_rank}" + f" in Logprob {prmpt_tok_lp}") + + # Validate that the logprob processor yields + # the correct prompt log probs and valid + # rankings + rank_one_appears = False + for jdx in range(1, len(ref_pos_prompt_logprob_toks)): + # Iterate over the (logprob val,logprob tok id) + # pairs expected by the test fixture at this + # position in the completion. + ref_plp_val = float(ref_pos_prompt_logprob_vals[jdx]) + ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx]) + assert ref_tok_id in pos_logprob_dict, ( + f"Expected token {ref_tok_id} to be" + f" in logprob dict but it is not.") + + # Extract actually-generated logprob + # info + plp = pos_logprob_dict[ref_tok_id] + plp_val = plp.logprob + plp_rank = plp.rank + + # A "top" (rank 1) logprob must be + # present + rank_one_appears = (True + if plp_rank == 1 else rank_one_appears) + + # Rank must be >= 1 + assert plp_rank >= 1, ( + f"Logprob {plp} has invalid" + f" rank {plp_rank} < 1." + f" Logprob dict: {pos_logprob_dict}") + + # Validate log probability + assert math.isclose(plp_val, ref_plp_val), ( + f"Token id {ref_tok_id} appears in logprobs dict" + f" at position {idx} in completion with log" + f" probability {plp_val} but {ref_plp_val} was" + f" expected. 
Logprob: {plp}") + + assert rank_one_appears, (f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}") + + # Validate prompt logprob detokenization + for plp_tok in pos_logprob_dict: + # Confirm that prompt logprob decoded token matches + # the logprob token id at this sequence position + decoded_token = pos_logprob_dict[plp_tok].decoded_token + ref_decoded_token = _ref_convert_id_to_token( + dtv.tokenizer, plp_tok) + assert decoded_token == ref_decoded_token, ( + f"Prompt logprob token id {plp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})") + else: + # Prompt logprobs disabled for this request + assert prompt_logprobs is None + + +@pytest.mark.parametrize( + "request_output_kind", + [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", + [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_logprobs_processor(request_output_kind: RequestOutputKind, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], + dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens, + generated_logprobs_raw=None if num_sample_logprobs is None else + dummy_test_vectors.generation_logprobs, + prompt_logprobs_raw=None + if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs) + + # Make N requests. + request_id_list = [ + f"request-{idx}" + for idx in range(len(dummy_test_vectors.prompt_strings)) + ] + requests = [ + EngineCoreRequest(request_id=request_id_list[idx], + prompt=prompt, + prompt_token_ids=prompt_tokens, + arrival_time=0, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + eos_token_id=None, + lora_request=None, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) + ] + + # Add requests to the detokenizer. + for request in requests: + output_processor.add_request(request) + + gen_tokens = {} + gen_logprobs = {} + gen_prompt_logprobs = {} + gen_cumulative_logprobs = {} + while True: + # Mock output from the EngineCore. + outputs = engine_core.get_outputs() + if len(outputs) == 0: + break + + # Step the logprobs processor. + processed_outputs = output_processor.process_outputs(outputs) + request_outputs = processed_outputs.request_outputs + requests_to_abort = processed_outputs.reqs_to_abort + assert len(requests_to_abort) == 0 + + # Update tracking. 
+ for request_output in request_outputs: + request_id = request_output.request_id + new_tokens = request_output.outputs[0].token_ids + prompt_logprobs = request_output.prompt_logprobs + logprobs = request_output.outputs[0].logprobs + gen_cumulative_logprobs[request_id] = request_output.outputs[ + 0].cumulative_logprob + if request_id not in gen_logprobs: + # Start tracking sample and prompt logprobs for this request + gen_tokens[request_id] = new_tokens + gen_logprobs[request_id] = logprobs + gen_prompt_logprobs[request_id] = prompt_logprobs + else: + # Extend logprobs tracker + gen_tokens[request_id].extend(new_tokens) + lp = gen_logprobs[request_id] + plp = gen_prompt_logprobs[request_id] + if lp: + lp.extend(logprobs) + if plp: + plp.extend(prompt_logprobs) + + # Confirmed tracked logprobs match what we expect + _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, + gen_cumulative_logprobs, dummy_test_vectors, + request_id_list, num_sample_logprobs, + num_prompt_logprobs) + + assert output_processor.get_num_unfinished_requests() == 0 + assert not output_processor.has_unfinished_requests() + + @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -def test_stop_string(include_stop_str_in_output: bool): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) - engine_core = MockEngineCore(GENERATION_TOKENS) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", + [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_stop_string(include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens, + generated_logprobs_raw=dummy_test_vectors.generation_logprobs + if num_sample_logprobs else None, + prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs + if num_prompt_logprobs else None) # Make N requests. + request_id_list = [ + f"request-{idx}" + for idx in range(len(dummy_test_vectors.prompt_strings)) + ] requests = [ EngineCoreRequest( - request_id=f"request-{idx}", + request_id=request_id_list[idx], prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, @@ -166,9 +507,11 @@ def test_stop_string(include_stop_str_in_output: bool): output_kind=RequestOutputKind.DELTA, stop=STOP_STRINGS, include_stop_str_in_output=include_stop_str_in_output, - )) for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add requests to the detokenizer. @@ -176,6 +519,10 @@ def test_stop_string(include_stop_str_in_output: bool): output_processor.add_request(request) gen_strings = {} + gen_tokens = {} + gen_logprobs = {} + gen_prompt_logprobs = {} + gen_cumulative_logprobs = {} aborted = [] while True: # Mock output from the EngineCore. 
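The logprobs checks in the tests above and below all rely on the same invariant: each streamed position reports the requested top-k candidate logprobs plus, when the sampled token is not already among them, one extra entry for the sampled token itself. A minimal, self-contained sketch of that check (illustrative only; the helper name and signature are hypothetical, not vLLM internals):

```python
from typing import Dict, List


def check_sample_logprobs(sampled_ids: List[int],
                          per_pos_logprobs: List[Dict[int, float]],
                          k: int) -> None:
    # One logprob dict per sampled token.
    assert len(sampled_ids) == len(per_pos_logprobs)
    for pos, (tok_id, lp_dict) in enumerate(
            zip(sampled_ids, per_pos_logprobs)):
        # The sampled token must always appear in its position's dict.
        assert tok_id in lp_dict, f"sampled token missing at position {pos}"
        # Either k entries (sampled token was in the top k) or k + 1
        # (sampled token was appended because it fell outside the top k).
        assert len(lp_dict) in (k, k + 1)
```
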
@@ -199,14 +546,29 @@ def test_stop_string(include_stop_str_in_output: bool): request_id = request_output.request_id new_text = request_output.outputs[0].text + new_tokens = request_output.outputs[0].token_ids + prompt_logprobs = request_output.prompt_logprobs + logprobs = request_output.outputs[0].logprobs + gen_cumulative_logprobs[request_id] = request_output.outputs[ + 0].cumulative_logprob if request_id not in gen_strings: gen_strings[request_id] = new_text + gen_tokens[request_id] = new_tokens + gen_logprobs[request_id] = logprobs + gen_prompt_logprobs[request_id] = prompt_logprobs else: gen_strings[request_id] += new_text + gen_tokens[request_id].extend(new_tokens) + lp = gen_logprobs[request_id] + plp = gen_prompt_logprobs[request_id] + if lp: + lp.extend(logprobs) + if plp: + plp.extend(prompt_logprobs) # Confirmed tracked values matches what we expected. - for idx, (ref_gen_str, - stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)): + for idx, (ref_gen_str, stop_str) in enumerate( + zip(dummy_test_vectors.generation_strings, STOP_STRINGS)): # Request should be aborted. request_id = f"request-{idx}" @@ -227,13 +589,20 @@ def test_stop_string(include_stop_str_in_output: bool): assert gen_str == ref_str_exc_stop, ( f"{gen_str=}, {ref_str_exc_stop=}") + # Confirmed tracked logprobs match what we expect + _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, + gen_cumulative_logprobs, dummy_test_vectors, + request_id_list, num_sample_logprobs, + num_prompt_logprobs) + assert output_processor.get_num_unfinished_requests() == 0 assert not output_processor.has_unfinished_requests() -def test_iteration_stats(): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=True) - engine_core = MockEngineCore(GENERATION_TOKENS) +def test_iteration_stats(dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=True) + engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) # Make N requests. requests = [ @@ -248,13 +617,13 @@ def test_iteration_stats(): eos_token_id=None, lora_request=None, sampling_params=SamplingParams(), - ) for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + ) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add all requests except one to the OutputProcessor. 
- num_active = len(GENERATION_TOKENS) - 1 + num_active = len(dummy_test_vectors.generation_tokens) - 1 for request in requests[:num_active]: output_processor.add_request(request) inactive_request = requests[num_active] @@ -263,8 +632,10 @@ def test_iteration_stats(): outputs = engine_core.get_outputs()[:num_active] processed_outputs = output_processor.process_outputs(outputs) iteration_stats = processed_outputs.iteration_stats - total_prompt_tokens = sum( - [len(prompt_tokens) for prompt_tokens in PROMPT_TOKENS[:num_active]]) + total_prompt_tokens = sum([ + len(prompt_tokens) + for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] + ]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active @@ -283,7 +654,7 @@ def test_iteration_stats(): outputs = engine_core.get_outputs()[:num_active] processed_outputs = output_processor.process_outputs(outputs) iteration_stats = processed_outputs.iteration_stats - total_prompt_tokens = len(PROMPT_TOKENS[num_active - 1]) + total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py new file mode 100644 index 0000000000000..39248ce86f25a --- /dev/null +++ b/tests/v1/engine/utils.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from vllm.engine.arg_utils import EngineArgs +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.v1.engine import EngineCoreOutput, FinishReason +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +GeneralTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + +# Number of sample logprobs to request when testing sample logprobs +NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5 +# Number of prompt logprobs to request when testing prompt logprobs +NUM_PROMPT_LOGPROBS_UNDER_TEST = 7 + +TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" + +FULL_STRINGS = [ + "My name is Robert from Neural Magic and I love working on vLLM so much!", + "Red Hat is the best open source company by far across Linux, K8s, and AI.", + "Nick is the name of my brother in addition to my colleague from Red Hat.", +] +STOP_STRINGS = ["I love working on", "company by far", "brother in"] +PROMPT_LEN = 5 + +PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet " + "supported on VLLM V1.") + +random.seed(42) + + +def _create_random_top_logprob_test_vector( + num_logprobs: int, + lower: float, + upper: float, +) -> torch.Tensor: + """Create a random vector of top logprob float values. + + Use to create fake sample logprobs for testing. + + Note that a real production scenario would require + logprobs to be sorted in descending order, something + which is omitted in this function. 
+ + Args: + num_logprobs: number of top logprobs + lower: lower range of logprob float values + upper: upper range of logprob float values + + Returns: + 1D length-`num_logprobs` torch Tensor of float logprob values + """ + return torch.rand(num_logprobs) * (upper - lower) + lower + + +def _create_random_top_logprob_test_matrix( + shape: Tuple, + lower: float, + upper: float, +) -> torch.Tensor: + """Create a random matrix of top logprob float values. + + Use to create fake prompt logprobs for testing. + + Note that a real production scenario would require + logprobs to be sorted in descending order along rows, + something which is omitted in this function. + + Args: + shape: (num_tokens,num_logprobs) tuple representing + matrix shape + lower: lower range of logprob float values + upper: upper range of logprob float values + + Returns: + 2D num_tokens x num_logprobs torch Tensor of float logprob values + """ + return torch.rand(*shape) * (upper - lower) + lower + + +def _create_random_top_token_test_vector( + num_logprobs: int, + lower: int, + upper: int, + sampled_token_id: int, + adjust_num_logprobs: bool = True) -> Tuple[torch.Tensor, int]: + """Create a random vector of top logprob token indices + + Use to create fake sample logprobs for testing. The sampled token + ID must always be one of the top logprobs, which this dummy test + vector generator enforces. OpenAI API + compatible engines must be able to return an additional sample + logprob for the sampled token if the sampled token was not + among the top sample logprobs; `adjust_num_logprobs` emulates + this behavior by increasing the vector length by 1 if + `adjust_num_logprobs` is set. + + Args: + num_logprobs: number of top logprobs + lower: lower range of token ids + upper: upper range of token ids + sampled_token_id: the token actually sampled + adjust_num_logprobs: if True, emulate situation where sampled + token logprob must be injected into top + logprobs + + Returns: + 1D length-x torch Tensor of token ids where x is + `num_logprobs+1` if `adjust_num_logprobs` and + `num_logprobs` otherwise + sampled_token_rank: the rank of sampled_token_id in the vocab + vector when sorted in descending order by + logprob + """ + + # Calculate the final number of logprobs required + total_logprobs = num_logprobs + 1 if adjust_num_logprobs else num_logprobs + + # Generate random indices using torch + choice_tensor = torch.randperm(upper - lower)[:total_logprobs] + lower + + # Ensure the sampled token ID is included in the tensor + choice_tensor[0] = sampled_token_id + + # Check if the sampled_token_id occurs in choice_tensor[1:] + if sampled_token_id in choice_tensor[1:]: + sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero( + as_tuple=True)[0].item() + else: + # If not found, assign a random int between num_logprobs and 50700 + sampled_token_rank = random.randint(num_logprobs, 50700) + + return choice_tensor, sampled_token_rank + + +def _create_random_top_token_test_matrix( + shape: Tuple[int, int], + lower: int, + upper: int, + tokens_list: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Create a random matrix of top logprob token indices + + Use to create fake prompt logprobs for testing. + + Token ids are generated randomly and sampled without + replacement. 
+ + Args: + shape: (num_tokens, num_logprobs) tuple representing + matrix shape + lower: lower range of token ids + upper: upper range of token ids + + Returns: + Tuple containing: + - 2D num_tokens x num_logprobs+1 torch Tensor of token ids + - 1D tensor of ranks of prompt tokens in their respective + rows, or random values + """ + num_elements = shape[0] * shape[1] + choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower + matrix = torch.cat( + (torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1), + choice_tensor.view(shape)), + dim=1) + + # Initialize the tensor for storing the ranks + prompt_token_ranks = torch.empty(shape[0], dtype=torch.int) + + # Iterate over each row to check presence of + # tokens_list[rdx] and determine its index + for rdx in range(shape[0]): + row = matrix[rdx, + 1:] # Skip the first column as it contains the token list + token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0] + if token_index.numel() > 0: + prompt_token_ranks[rdx] = token_index.item() + else: + prompt_token_ranks[rdx] = random.randint(shape[1], 50700) + + return matrix, prompt_token_ranks + + +def decode_token( + tok_id: int, + tokenizer: PreTrainedTokenizer, +) -> str: + """Reproduce the process of detokenizing a token for testing purposes. + + Args: + tok_id: token id to detokenize + tokenizer: tokenizer to use for detokenization + + Returns: + string representation of token + """ + return tokenizer.convert_ids_to_tokens(tok_id) + + +def generate_dummy_sample_logprobs( + sampled_tokens_list: List, + num_logprobs: int, + tokenizer: PreTrainedTokenizer, +) -> List[Tuple[List[int], List[float], int]]: + """Generate dummy sample logprobs + + Generate a test data structure which imitates the list of sample logprobs + which would be assembled in the engine core during decode phase. + + Args: + sampled_tokens_list: list of sampled tokens + num_logprobs: return `num_logprobs` or `num_logprobs+1` logprobs per token + tokenizer: model tokenizer to use for detokenization + + Returns + List of (top token ids vector, logprobs vector, sampled token rank) + Python lists tuples; in each tuple the logprobs and top token ids + vectors have the same length which is either `num_logprobs` or + `num_logprobs+1`. Sampled token rank is the rank (index+1) of the + sampled token within the vocab vector when sorted by logprob in + descending order. + """ + res = [] + for sampled_token_id in sampled_tokens_list: + ( + token_vector, + sampled_token_rank, + ) = _create_random_top_token_test_vector(num_logprobs, 0, + len(tokenizer.vocab) - 1, + sampled_token_id) + + res.append( + (token_vector, + _create_random_top_logprob_test_vector(num_logprobs + 1, -100, + 0), sampled_token_rank)) + + # Convert tensors in the list tuples to Python lists + res_list_format = [ + (log_probs_tensor.tolist(), token_ids_tensor.tolist(), + sampled_token_rank) + for log_probs_tensor, token_ids_tensor, sampled_token_rank in res + ] + + return res_list_format + + +def generate_dummy_prompt_logprobs_tensors( + prompt_tokens_list: List, + num_logprobs: int, + tokenizer: PreTrainedTokenizer, +) -> LogprobsTensors: + """Generate dummy prompt logprobs tensors + + Generate a test data structure which imitates the torch Tensors of prompt + logprobs which would be assembled in the engine core during chunked + prefill. 
+ + Args: + prompt_tokens_list: list of prompt tokens + num_logprobs: return `num_logprobs` logprobs per token + tokenizer: model tokenizer to use for detokenization + + Returns + Single Tuple of (logprobs matrix, top token ids matrix) torch Tensor, + where both matrices have dimensions + num_prompt_tokens x num_logprobs + """ + # For now, assume the whole prompt is processed in one chunk; thus, + # the number of non-`None` prompt logprobs is `len(prompt_tokens_list)-1`. + # Prior to injecting `None` at the beginning of prompt logprobs (which + # happens later in the detokenizer, not here), the prompt logprobs in + # the ith position are predicting the probability distribution of the + # prompt token in (i+1)st position. Thus, we concat + # `prompt_tokens_list[1:]` to the dummy token ids, just as the engine + # would. + num_prompt_logprobs = len(prompt_tokens_list) - 1 + ( + token_vector, + prompt_token_ranks, + ) = _create_random_top_token_test_matrix( + (num_prompt_logprobs, num_logprobs), 0, + len(tokenizer.vocab) - 1, prompt_tokens_list[1:]) + return LogprobsTensors( + token_vector, + _create_random_top_logprob_test_matrix( + (num_prompt_logprobs, num_logprobs + 1), -100, 0), + prompt_token_ranks) + + +@dataclass +class DummyOutputProcessorTestVectors: + """Dummy test vectors for output processor tests""" + tokenizer: GeneralTokenizerType + tokenizer_group: BaseTokenizerGroup + vllm_config: EngineArgs + full_tokens: List[List[int]] # Prompt + generated tokens + prompt_tokens: List[List[int]] + generation_tokens: List[List[int]] + # Each request is associated with a tuple of + # (top tokens, top logprobs, ranks) prompt logprobs tensors + prompt_logprobs: List[LogprobsTensors] + # Each request is associated with a sample logprobs; a request's + # sample logprobs are a list of (top tokens, top logprobs, ranks) + # sample logprobs tensors at each sequence position + generation_logprobs: List[List[Tuple[List[int], List[float], int]]] + prompt_strings: List[str] + prompt_strings_len: List[int] + generation_strings: List[str] + + +class MockEngineCore: + """Mock engine core outputs form premade tokens lists.""" + + def __init__( + self, + tokens_list: List[List[int]], + # For each request, for each sampled token offset, + # a tuple of + # (list of topk token ids, list of sample logprob vals, rank) + generated_logprobs_raw: Optional[List[List[Tuple[List[int], + List[float], + int]]]] = None, + # For each request, a tuple of + # (prompt logprob val matrix, prompt logprob tok id matrix); + # each matrix has dimensions + # (num prompt toks) x (num prompt logprobs+1) + prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None, + ) -> None: + self.tokens_list = tokens_list + self.current_idx = 0 + self.generated_logprobs_raw = generated_logprobs_raw + self.do_logprobs = generated_logprobs_raw is not None + self.prompt_logprobs_raw = prompt_logprobs_raw + self.do_prompt_logprobs = prompt_logprobs_raw is not None + + def get_outputs(self) -> List[EngineCoreOutput]: + do_logprobs = self.do_logprobs + do_prompt_logprobs = self.do_prompt_logprobs + token_idx = self.current_idx + + outputs = [] + for req_idx, token_ids in enumerate(self.tokens_list): + if len(token_ids) > token_idx: + if do_logprobs: + assert self.generated_logprobs_raw is not None + (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = ( + self.generated_logprobs_raw[req_idx][token_idx]) + logprobs = LogprobsLists( + [logprobs_token_ids_], + [logprobs_], + [sampled_token_ranks_], + ) + else: + logprobs = None + if 
do_prompt_logprobs: + if self.current_idx == 0: + assert self.prompt_logprobs_raw is not None + prompt_logprobs = self.prompt_logprobs_raw[req_idx] + else: + prompt_logprobs = None + else: + prompt_logprobs = None + output = EngineCoreOutput( + request_id=f"request-{req_idx}", + new_token_ids=[token_ids[token_idx]], + new_logprobs=logprobs, + new_prompt_logprobs_tensors=prompt_logprobs, + ) + if token_idx == len(token_ids) - 1: + output.finish_reason = FinishReason.STOP + outputs.append(output) + + self.current_idx += 1 + return outputs diff --git a/tests/v1/entrypoints/__init__.py b/tests/v1/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py new file mode 100644 index 0000000000000..b00e168db9d32 --- /dev/null +++ b/tests/v1/entrypoints/conftest.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + + +@pytest.fixture +def sample_prompts(): + return [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + +@pytest.fixture +def sample_token_ids(): + return [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], + ] + + +@pytest.fixture +def sample_regex(): + return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + + +@pytest.fixture +def sample_json_schema(): + return { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] + } + + +@pytest.fixture +def sample_complex_json_schema(): + return { + "type": "object", + "properties": { + "score": { + "type": "integer", + "minimum": 0, + "maximum": 100 # Numeric range + }, + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + "email": { + "type": "string", + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + }, + "tags": { + "type": "array", + "items": { + "type": "string", + "pattern": + "^[a-z]{1,10}$" # Combining length and pattern restrictions + } + } + }, + "required": ["score", "grade", "email", "tags"] + } + + +@pytest.fixture +def sample_definition_json_schema(): + return { + '$defs': { + 'Step': { + 'properties': { + 'explanation': { + 'title': 'Explanation', + 'type': 'string' + }, + 'output': { + 'title': 'Output', + 'type': 'string' + } + }, + 'required': ['explanation', 'output'], + 'title': 'Step', + 'type': 'object' + } + }, + 'properties': { + 'steps': { + 'items': { + '$ref': '#/$defs/Step' + }, + 'title': 'Steps', + 'type': 'array' + }, + 'final_answer': { + 'title': 'Final Answer', + 'type': 'string' + } + }, + 'required': ['steps', 'final_answer'], + 'title': 'MathReasoning', + 'type': 'object' + } + + +@pytest.fixture +def sample_guided_choice(): + return [ + "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", + "Ruby", "Swift", "Kotlin" + ] + + +@pytest.fixture +def sample_sql_statements(): + return (""" +start: select_statement +select_statement: "SELECT" column "from" table "where" condition +column: "col_1" | "col_2" +table: "table_1" | "table_2" +condition: column 
"=" number +number: "1" | "2" +""") diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py new file mode 100644 index 0000000000000..ef46a16ef3447 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Dict, List, Optional + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +from openai import BadRequestError + +from tests.utils import RemoteOpenAIServer +from vllm.transformers_utils.tokenizer import get_tokenizer + +# any model with a chat template should work here +MODEL_NAME = "facebook/opt-125m" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager" + ] + + +@pytest.fixture(scope="module", + params=[["--no-enable-prefix-caching"], + [ + "--no-enable-prefix-caching", + "--disable-frontend-multiprocessing" + ]]) +def server(default_server_args, request): + if request.param: + default_server_args.extend(request.param) + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_single_completion(client: openai.AsyncOpenAI, + model_name: str) -> None: + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await 
client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str) -> None: + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), + (MODEL_NAME, 0), + (MODEL_NAME, 1), + (MODEL_NAME, None)]) +async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, + model_name: str, + prompt_logprobs: Optional[int]): + params: Dict = { + "prompt": ["A robot may not injure another robot", "My name is"], + "model": model_name, + } + if prompt_logprobs is not None: + params["extra_body"] = {"prompt_logprobs": prompt_logprobs} + + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): + await client.completions.create(**params) + else: + completion = await client.completions.create(**params) + if prompt_logprobs is not None: + assert completion.choices[0].prompt_logprobs is not None + assert len(completion.choices[0].prompt_logprobs) > 0 + + assert completion.choices[1].prompt_logprobs is not None + assert len(completion.choices[1].prompt_logprobs) > 0 + + else: + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str) -> None: + prompt = "What is an LLM?" 
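+    # The streamed chunks, concatenated in order, should reproduce the text
+    # returned by the equivalent non-streaming request issued below.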
+ + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, + }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await 
client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but + # not necessary for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= 
logprobs_arg + 1 + assert len(logprobs.tokens) > 5 diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py new file mode 100644 index 0000000000000..86c576cd70a57 --- /dev/null +++ b/tests/v1/sample/test_logprobs.py @@ -0,0 +1,392 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from typing import List, Tuple + +import pytest +import torch + +from tests.kernels.utils import override_backend_env_variable +from tests.v1.sample.utils import ( + assert_incr_detok_str_matches_non_incr_detok_str, + compute_correct_cumulative_logprob, get_test_batch) +from vllm import SamplingParams + +from ...conftest import VllmRunner + +MODEL = "meta-llama/Llama-3.2-1B" +DTYPE = "half" + + +@pytest.fixture(scope="module") +def vllm_model(vllm_runner): + with vllm_runner( + MODEL, + dtype=DTYPE, + max_logprobs=7, + # Very small number of batched tokens to ensure + # that we test chunking. + max_num_batched_tokens=16, + max_num_seqs=16, + max_model_len=128, + enforce_eager=True, + #TODO: enable this once we support it for + # prompt logprobs. + enable_prefix_caching=False, + gpu_memory_utilization=0.5, + ) as vllm_model: + yield vllm_model + + +@pytest.fixture(scope="module") +def hf_model(hf_runner): + with hf_runner(MODEL, dtype=DTYPE) as hf_model: + yield hf_model + + +def _repeat_logprob_config( + test_prompts, + logprob_prompt_logprob_list: List[Tuple], +) -> List[Tuple]: + """Ensure each test prompt has a logprob config. + + A logprob config specifies the optional (i.e. + may-be-`None`) number of sample logprobs and + the optional number of prompt logprobs. + + If more test prompts than logprob configs are + provided, the provided logprob configs are + tiled to match the number of test prompts. + + If fewer test prompts than logprob configs + are provided, the list of logprob configs + is truncated to match the number of test + prompts. + + Otherwise, the list of logprob configs + is returned as-is. 
+ + Args: + test_prompts: list of prompts under test + logprob_prompt_logprob_list: list of + (optional num sample logprob, + optional num prompt logprob) + tuples + + Returns: + List of + (optional num sample logprob,optional num prompt logprob) + tuples which is either identical to + `logprob_prompt_logprob_list`, or else repeats + `logprob_prompt_logprob_list` enough times to match the + number of `test_prompts`, or else is truncated to match + the number of `test_prompts` + """ + num_test_prompts = len(test_prompts) + # Make sure there is a logprobs configuration for each test prompt + logprob_prompt_logprob_list = list( + itertools.islice(itertools.cycle(logprob_prompt_logprob_list), + num_test_prompts)) + # Now the number of prompts should match the number of sample params combos + assert num_test_prompts == len(logprob_prompt_logprob_list) + return logprob_prompt_logprob_list + + +def _test_case_get_logprobs_and_prompt_logprobs( + hf_model, + vllm_model, + batch_logprobs_composition: str, + temperature: float, + example_prompts, +) -> None: + test_prompts = example_prompts + + max_tokens = 5 + hf_outputs = hf_model.generate_greedy( + test_prompts, + max_tokens=max_tokens, + ) + hf_logprobs = hf_model.generate_greedy_logprobs( + test_prompts, + max_tokens=max_tokens, + ) + + # Batch has mixed sample params + # (different logprobs/prompt logprobs combos) + logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) + + # Ensure that each test prompt has a logprob config for testing + logprob_prompt_logprob_list = _repeat_logprob_config( + test_prompts, logprob_prompt_logprob_list) + # Generate SamplingParams + vllm_sampling_params = [ + SamplingParams(max_tokens=max_tokens, + logprobs=num_lp, + prompt_logprobs=num_plp, + temperature=temperature, + seed=1984) + for num_lp, num_plp in logprob_prompt_logprob_list + ] + + vllm_results = vllm_model.model.generate( + test_prompts, sampling_params=vllm_sampling_params) + + for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip( + vllm_results, hf_logprobs, hf_outputs, + logprob_prompt_logprob_list): + + # Extract request-level (prompt)logprobs config + num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob + + # Test whether sampled token output is consistent between vLLM and HF + # vLLM prompt+completion should match HF output + if temperature == 0.0: + assert (vllm_result.prompt_token_ids + + vllm_result.outputs[0].token_ids == hf_output[0]) + else: + # Sampled tokens won't match if not greedy + assert (vllm_result.prompt_token_ids == hf_output[0] + [:len(vllm_result.prompt_token_ids)]) + + # Validate sample logprobs + if num_top_logprobs is not None: + assert num_top_logprobs is not None + # Confirm that the structure of the sample logprobs in the result is + # correct + assert vllm_result.outputs[0].logprobs is not None + assert len(vllm_result.outputs[0].logprobs) == max_tokens + for logprobs, token_id in zip(vllm_result.outputs[0].logprobs, + vllm_result.outputs[0].token_ids): + assert logprobs is not None + + # Confirm that the output token appears among the logprobs + assert token_id in logprobs + token_in_topk = logprobs[token_id].rank <= num_top_logprobs + + # If the output token is not included in the top K + # logprob, it can return 1 more data + if token_in_topk and num_top_logprobs != 0: + assert len(logprobs) == num_top_logprobs + else: + assert len(logprobs) == num_top_logprobs + 1 + + if num_top_logprobs > 0: + # We should have an entry for each of the topk ranks + all_ranks = {lp.rank 
for lp in logprobs.values()} + assert all(r in all_ranks + for r in range(1, num_top_logprobs + 1)) + + output_text = vllm_result.outputs[0].text + output_string_from_most_likely_tokens_lst: List[str] = [] + for top_logprobs in vllm_result.outputs[0].logprobs: + top_logprob = next(iter(top_logprobs.values())) + output_string_from_most_likely_tokens_lst.append( + top_logprob.decoded_token) + + output_string_from_most_likely_tokens = "".join( + output_string_from_most_likely_tokens_lst) + assert_incr_detok_str_matches_non_incr_detok_str( + output_text, output_string_from_most_likely_tokens, + "The output text from the top logprob for each token " + "position should be the same as the output text in the " + "result.") + + # Compare vLLM sample logprobs to HF + vllm_sample_logprobs = vllm_result.outputs[0].logprobs + for i, top_logprobs in enumerate(vllm_sample_logprobs): + for token_id, sample_logprob in top_logprobs.items(): + if temperature == 0.0 or i == 0: + logprob = sample_logprob.logprob + torch.testing.assert_close( + logprob, + hf_logprob[i][-1][token_id].item(), + atol=1e-2, + rtol=1e-2) + assert isinstance( + sample_logprob.decoded_token, + str), ("The token should be decoded by the time it is" + " returned to the user.") + + # At this point we know the sample logprobs are correct for this + # request. Validate that cumulative_logprob is actually the sum. + # For each request, assert that the returned cumulative logprob + # matches the correct value, which is computed below. + torch.testing.assert_close( + vllm_result.outputs[0].cumulative_logprob, + compute_correct_cumulative_logprob(vllm_result.outputs[0]), + atol=1e-6, + rtol=1e-6) + else: + # Logprobs disabled for this request; should be None + assert vllm_result.outputs[0].logprobs is None + + # Validate prompt logprobs + if num_top_prompt_logprobs is not None: + # Confirm that structure of prompt logprobs in result is correct + assert vllm_result.prompt_logprobs is not None + # - The first prompt logprob is always None + assert vllm_result.prompt_logprobs[0] is None + # - Prompt logprobs are returned for all indices in + # the prompt + assert len(vllm_result.prompt_logprobs) == len( + vllm_result.prompt_token_ids) + for prompt_logprobs, prompt_token_id in zip( + vllm_result.prompt_logprobs[1:], + vllm_result.prompt_token_ids[1:]): + assert prompt_logprobs is not None + + # Confirm that the prompt token appears among the logprobs + assert prompt_token_id in prompt_logprobs + token_in_topk = prompt_logprobs[ + prompt_token_id].rank <= num_top_prompt_logprobs + + # If the prompt token is not included in the top K + # logprob, it can return 1 more data + if token_in_topk and num_top_prompt_logprobs != 0: + assert len(prompt_logprobs) == num_top_prompt_logprobs + else: + assert len(prompt_logprobs) == num_top_prompt_logprobs + 1 + + if num_top_prompt_logprobs > 0: + # We should have an entry for each of the topk ranks + all_ranks = {lp.rank for lp in prompt_logprobs.values()} + assert all(r in all_ranks + for r in range(1, num_top_prompt_logprobs + 1)) + + # Compare prompt logprobs to HF + # The first prompt logprob is always None, so we compare it from + # 1:. 
+ vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] + for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): + for token_id, logprob in vllm_prompt_logprob_dict.items(): + torch.testing.assert_close( + logprob.logprob, + hf_logprob[0][i][token_id].item(), + atol=2e-2, + rtol=2e-2) + else: + assert vllm_result.prompt_logprobs is None + + +#@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("batch_logprobs_composition", + ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"]) +@pytest.mark.parametrize("temperature", [0.0, 2.0]) +def test_get_logprobs_and_prompt_logprobs( + hf_model, + vllm_model, + batch_logprobs_composition: str, + temperature: float, + example_prompts, +) -> None: + """Test V1 Engine logprobs & prompt logprobs + + Exercise a variety of combinations of `logprobs` and `prompt_logprobs` + settings and validate that + * The generated logprobs and prompt logprobs are consistent with the + configuration settings, in terms of whether or not the logprobs + (of either type) were requested and how many were requested + * The generated logprobs are consistent with the generated tokens + * The generated (prompt)logprobs are consistent with HuggingFace + (prompt)logprobs, as a reference + + batch_logprobs_composition controls the logprobs configurations for + requests in the batch under test. + + Args: + hf_model + vllm_model + batch_logprobs_composition: logprobs configuration for test batch + example_prompts + monkeypatch + """ + _test_case_get_logprobs_and_prompt_logprobs( + hf_model=hf_model, + vllm_model=vllm_model, + batch_logprobs_composition=batch_logprobs_composition, + temperature=temperature, + example_prompts=example_prompts) + + +def test_max_logprobs(monkeypatch): + """vLLM v1 engine should fail a request with `logprobs > max_logprobs` + + Should also fail for `prompt_logprobs > max_logprobs` + + Args: + monkeypatch + """ + override_backend_env_variable(monkeypatch, "FLASH_ATTN") + + runner = VllmRunner("facebook/opt-125m", + max_logprobs=1, + enable_prefix_caching=False, + max_model_len=256) + vllm_sampling_params = SamplingParams(logprobs=1) + # should pass + runner.generate(["Hello world"], sampling_params=vllm_sampling_params) + + bad_sampling_params = SamplingParams(logprobs=2) + with pytest.raises(ValueError): + runner.generate(["Hello world"], sampling_params=bad_sampling_params) + + +def test_none_logprobs(vllm_model, example_prompts, monkeypatch): + """Engine should return `logprobs` and `prompt_logprobs` as `None` + + Args: + vllm_model: vLLM model fixture + example_prompts: list of example prompts (test fixture) + monkeypatch: supports editing env vars and rolling back changes + after the test + """ + max_tokens = 5 + + sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, + logprobs=None, + prompt_logprobs=None, + temperature=0.0) + results_logprobs_none = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params_logprobs_none) + + for i in range(len(results_logprobs_none)): + # Check sample logprobs are None + assert results_logprobs_none[i].outputs[0].logprobs is None + assert results_logprobs_none[i].outputs[0].cumulative_logprob is None + # Check prompt logprobs are None + assert results_logprobs_none[i].prompt_logprobs is None + + +def test_zero_logprobs(vllm_model, example_prompts, monkeypatch): + """Engine should return sampled token and prompt token logprobs + + Args: + vllm_model: vLLM model fixture + example_prompts: list of example prompts (test fixture) + monkeypatch: supports editing env vars 
and rolling back changes + after the test + """ + max_tokens = 5 + + sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens, + logprobs=0, + prompt_logprobs=0, + temperature=0.0) + results_logprobs_zero = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params_logprobs_zero) + + for i in range(len(results_logprobs_zero)): + # Check that there is one sample logprob dict for each + # sample token + logprobs = results_logprobs_zero[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_zero[i].prompt_logprobs + sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids + prompt_token_ids = results_logprobs_zero[i].prompt_token_ids + assert logprobs is not None + assert len(sampled_token_ids) == len(logprobs) + assert results_logprobs_zero[i].outputs[ + 0].cumulative_logprob is not None + # Check that there is one prompt logprob dict for each + # prompt token + assert prompt_logprobs is not None + assert len(prompt_token_ids) == len(prompt_logprobs) diff --git a/tests/v1/sample/test_logprobs_e2e.py b/tests/v1/sample/test_logprobs_e2e.py new file mode 100644 index 0000000000000..28c177fd497c2 --- /dev/null +++ b/tests/v1/sample/test_logprobs_e2e.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 + +import lm_eval + +from ...utils import RemoteOpenAIServer + +# arc-easy uses prompt_logprobs=1, logprobs=1 +TASK = "arc_easy" +FILTER = "acc_norm,none" +RTOL = 0.03 +EXPECTED_VALUE = 0.62 + +# FIXME(rob): enable prefix caching once supported. +MODEL = "meta-llama/Llama-3.2-1B" +MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501 +SERVER_ARGS = [ + "--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests" +] +NUM_CONCURRENT = 100 + + +def test_prompt_logprobs_e2e(): + results = lm_eval.simple_evaluate(model="vllm", + model_args=MODEL_ARGS, + tasks=TASK, + batch_size="auto") + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + +def test_promt_logprobs_e2e_server(): + with RemoteOpenAIServer(MODEL, SERVER_ARGS) as remote_server: + url = f"{remote_server.url_for('v1')}/completions" + + model_args = ( + f"model={MODEL}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py new file mode 100644 index 0000000000000..e1465b1239661 --- /dev/null +++ b/tests/v1/sample/utils.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import List, Tuple + +from vllm import CompletionOutput + + +def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]: + """Generate logprobs configs for a batch of requests + + A given request's logprobs configuration is (1) num_sample_logprobs and (2) + num_prompt_logprobs. The batch logprobs configuration is the list of request + logprobs configs. 
+ + batch_logprobs_composition == "NONE" yields a batch with no sample or prompt + logprobs + + batch_logprobs_composition == "SAMPLE" yields a batch with some requests + configured for sample logprobs only, and others configured for no logprobs + + batch_logprobs_composition == "PROMPT" yields a batch with some requests + configured for prompt logprobs only, and others configured for no logprobs + + batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some + requests configured for sample logprobs and prompt logprobs, some configured + for only sample logprobs or only prompt logprobs, and some configured for + no logprobs + + Args: + batch_logprobs_composition: types of logprobs configs to include in batch + + Returns: + + List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs]) + tuples + """ + if batch_logprobs_composition == "NONE": + # No requests with sample or prompt logprobs + return [(None, None)] + elif batch_logprobs_composition == "SAMPLE": + # Requests requiring sample logprobs or no logprobs + return [ + (None, None), + (0, None), + (5, None), + (3, None), + ] + elif batch_logprobs_composition == "PROMPT": + # Requests requiring prompt logprobs or no logprobs + return [ + (None, None), + (None, 0), + (None, 6), + (None, 5), + ] + elif batch_logprobs_composition == "SAMPLE_PROMPT": + # Requests requiring either no logprobs, just + # sample logprobs, just prompt logprobs, or + # both sample and prompt logprobs + return [ + (None, None), + (0, None), + (5, None), + (3, None), + (0, 3), + (6, 0), + (6, 3), + (None, 6), + (None, 5), + (None, 0), + ] + else: + raise ValueError("Invalid logprobs batch configuration for test.") + + +def assert_incr_detok_str_matches_non_incr_detok_str( + incremental_detokenization_str: str, + non_incremental_detokenization_str: str, + msg: str, +) -> None: + """Compare incrementally detok. text to non-incrementally detok. text + + Fail if the strings mismatch after non-alphanumeric characters are stripped + out. + + Rationale: incremental detokenization in the text generation process allows + the tokenizer to adjust the next token text output based on the token's + context in the string. However, logprobs detokenization detokenizes each + token individually, and the resultant strings may include some + non-alphanumeric placeholder characters where there could be i.e. + whitespace. So, this function compares only the alphanumeric text + between two strings and fails if there is a mismatch, which helps + with validating logprobs detokenization. 
+ + Args: + incremental_detokenization_str: incrementally-detokenized generated text + non_incremental_detokenization_str: non-incrementally-detokenized logprob + tokens + msg: error message if `assert` fails + """ + rgx = r'[^a-zA-Z0-9]+' + assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub( + rgx, '', non_incremental_detokenization_str)), (msg) + + +def compute_correct_cumulative_logprob( + completion_output: CompletionOutput) -> float: + """Compute known-good value for evaluating cumulative logprob + + Args: + completion_output: completion output from engine + + Returns: + Known-good cumulative logprob value + """ + token_ids = completion_output.token_ids + logprobs = completion_output.logprobs + assert logprobs is not None + return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)]) diff --git a/vllm/__init__.py b/vllm/__init__.py index 566c5116d5f09..457780824c743 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +# The version.py should be independent library, and we always import the +# version library first. Such assumption is critical for some customization. +from .version import __version__, __version_tuple__ # isort:skip + import os import torch @@ -19,8 +23,6 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from .version import __version__, __version_tuple__ - # set some common config/environment variables that should be set # for all processes created by vllm and all processes # that interact with vllm workers. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 971fe411695cb..5aca10079f9be 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,8 +14,8 @@ AttentionMetadataBuilder, AttentionType) from vllm.attention.backends.utils import ( - PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState, - compute_slot_mapping, compute_slot_mapping_start_idx, + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_flash_attn_version, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) @@ -640,6 +640,7 @@ def __init__( f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") self.attn_type = attn_type + self.vllm_flash_attn_version = get_flash_attn_version() def forward( self, @@ -759,7 +760,7 @@ def forward( alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=prefill_output, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) else: # prefix-enabled attention @@ -782,7 +783,7 @@ def forward( block_table=prefill_meta.block_tables, softcap=logits_soft_cap, out=prefill_output, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) if decode_meta := attn_metadata.decode_metadata: @@ -811,7 +812,7 @@ def forward( softcap=logits_soft_cap, block_table=decode_meta.block_tables, out=decode_output, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) else: # Use flash_attn_with_kvcache for normal decoding. 
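The `flash_attn.py` hunks around this point replace the import-time constant `VLLM_FLASH_ATTN_VERSION` with a per-instance `self.vllm_flash_attn_version = get_flash_attn_version()`. As a rough, editor-added sketch of that pattern (not part of the patch; `ExampleAttentionImpl` is a made-up stand-in, and only the helper name comes from the diff):

```python
# Illustrative sketch only: flash-attention version detection moves out of
# module import and into a helper that each implementation calls on __init__.
from typing import Optional


def get_flash_attn_version() -> Optional[int]:
    """Return a usable flash-attention version, or None if unavailable."""
    try:
        # The in-tree helper probes vllm.vllm_flash_attn and validates the
        # choice with is_fa_version_supported(); ImportError/AssertionError
        # fall back to None instead of failing at import time.
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported)
    except ImportError:
        return None
    return 2 if is_fa_version_supported(2) else None


class ExampleAttentionImpl:  # hypothetical class, for illustration only
    def __init__(self) -> None:
        # Resolved once per instance; forward() then passes it along as
        # fa_version=self.vllm_flash_attn_version, as the hunks below show.
        self.vllm_flash_attn_version = get_flash_attn_version()
```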
@@ -832,7 +833,7 @@ def forward( alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=decode_output.unsqueeze(1), - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) return output diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index c22f7e92103b8..a41140ec83782 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -12,7 +12,7 @@ from vllm.attention.backends.abstract import (AttentionLayer, AttentionMetadata, MLAAttentionImpl, T) -from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION +from vllm.attention.backends.utils import get_flash_attn_version from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -181,6 +181,7 @@ def __init__( self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj + self.vllm_flash_attn_version = get_flash_attn_version() def _v_up_proj_and_o_proj(self, x): if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: @@ -515,7 +516,7 @@ def _forward_prefill_flash( max_seqlen_k=max_prefill_seq_len, softmax_scale=self.scale, causal=True, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) attn_output = attn_output\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 9f6e731afd193..f363ba0c1e30c 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -2,6 +2,7 @@ from collections import defaultdict from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -15,6 +16,7 @@ if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that # lack attention. @@ -77,43 +79,39 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. 
- block_tables: Optional[torch.Tensor] - # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Placeholder. + block_tables: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @@ -125,11 +123,17 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.seq_start_loc is not None + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) # Placeholders slot_mapping = torch.empty(0) @@ -143,15 +147,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -169,6 +173,8 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, @@ -178,13 +184,16 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, @@ -235,8 +244,6 @@ def advance_step(self, assert self.context_lens_tensor is not None assert self.context_lens_tensor.shape == (num_queries, ) - assert self.block_tables is not None - # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): @@ -299,9 +306,6 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -323,15 +327,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap" - " (i.e., Gemma-2). Otherwise, the output might be wrong." 
- " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: @@ -341,48 +336,37 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: - num_decode_tokens = batch_size - + num_decode_tokens = batch_size - self.num_prefill_tokens assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in self.multimodal_placeholder_maps.items() } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) # Placeholders - slot_mapping = torch.empty(0) + slot_mapping_tensor = torch.empty(0) block_tables = torch.empty(0) return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, @@ -393,8 +377,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index e8a34434122c4..5c1f9916e22c2 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens( num_decode_query_tokens) -try: - from vllm.vllm_flash_attn.flash_attn_interface import ( - fa_version_unsupported_reason, is_fa_version_supported) +def get_flash_attn_version(): + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) - def flash_attn_version(): # if hopper default to FA3, otherwise stick to FA2 for now # TODO(lucas): profile FA3 on ampere to see if it makes sense to # use FA3 as default for 
both @@ -610,7 +610,5 @@ def flash_attn_version(): assert is_fa_version_supported(fa_version) return fa_version - - VLLM_FLASH_ATTN_VERSION = flash_attn_version() -except (ImportError, AssertionError): - VLLM_FLASH_ATTN_VERSION = None + except (ImportError, AssertionError): + return None diff --git a/vllm/config.py b/vllm/config.py index 5579d6936d105..426ba38080270 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1401,6 +1401,9 @@ def __post_init__(self) -> None: logger.info("Defaulting to use %s for distributed inference", backend) + if self.distributed_executor_backend is None and self.world_size == 1: + self.distributed_executor_backend = "uni" + self._verify_args() @property diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index e20c992a381a3..c408d4a67522c 100644 --- a/vllm/distributed/kv_transfer/README.md +++ b/vllm/distributed/kv_transfer/README.md @@ -14,8 +14,8 @@ The KV cache transfer contains three layer of abstractions: Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. -NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed -communication service already supports key-value-based lookup (like redis or +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or RDMA database). NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. @@ -27,4 +27,3 @@ The example usage is in [this file](../../../examples/online_serving/disaggregat Here is the diagram of how we run disaggretgated prefilling. ![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) - diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 5e1b62352d14c..3462f7de020ef 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -10,7 +10,6 @@ stop the prefill instance when the decode instance is slow. """ import threading -import time from collections import deque from typing import Deque, List, Optional, Union @@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase): def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float): """ - signal_pipe: on CPU - - NOTE: on-device recv will block all threads in the process, making the - KV cache producer unable to listen to new request while transmitting - KV cache. Luckily CPU recv only blocks the current thread so we use + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. 
Luckily CPU recv only blocks the current thread so we use CPU recv to listen to new request. - + data_pipe: on device (e.g. GPU) """ @@ -43,7 +42,7 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, self.buffer_size = 0 self.buffer_size_threshold = buffer_size_thresh - self.buffer_lock = threading.Lock() + self.buffer_cv = threading.Condition() self.signal_pipe = signal_pipe self.data_pipe = data_pipe self.request_handling_thread: Optional[threading.Thread] = None @@ -116,11 +115,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, hidden = hidden.clone() buffer_item = [input_tokens, roi, key, value, hidden] + data_size = sum([self._get_element_size(data) for data in buffer_item]) + + with self.buffer_cv: + if self.buffer_size + data_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size + data_size > self.buffer_size_threshold: + self.buffer_cv.wait() - with self.buffer_lock: - for data in buffer_item: - self.buffer_size += self._get_element_size(data) + self.buffer_size += data_size self.buffer.append(buffer_item) + self.buffer_cv.notify() def _is_end_signal(self, signal): return signal is None @@ -143,35 +150,31 @@ def drop_select_handler(self): roi = (roi > 0.5) tokens_roi_recver = [input_tokens, roi] - matched_length = 0 - - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: - + def is_buffer_available( + tokens_roi_recver: List[torch.Tensor], ) -> bool: + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break + if self._matches(self.buffer[0], + tokens_roi_recver) > 0: + return True # rotate the element we just accessed to the end self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) + return False + + with self.buffer_cv: + while not is_buffer_available(tokens_roi_recver): + logger.debug( + "KV transfer buffer is not available. Waiting...") + self.buffer_cv.wait() + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + self.buffer_cv.notify() except RuntimeError as e: if 'Connection closed by peer' not in str(e): @@ -208,20 +211,10 @@ def drop_select( return [input_tokens, roi, key, value, hidden] - def full_handler(self): - time.sleep(0.001) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: - if self.buffer_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged - # repeatedly. - logger.debug("KV transfer buffer is full. 
Handling...") - while self.buffer_size > self.buffer_size_threshold: - self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 321902d11fd73..bfc41703b94dc 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1024,13 +1024,6 @@ def initialize_model_parallel( backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if (world_size - != tensor_model_parallel_size * pipeline_model_parallel_size): - raise RuntimeError( - f"world_size ({world_size}) is not equal to " - f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") - # Build the tensor model-parallel groups. num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9df323..2e5bc75c6db38 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -434,6 +434,7 @@ def _initialize_kv_caches(self) -> None: @classmethod def _get_executor_cls(cls, engine_config: VllmConfig) -> Type[ExecutorBase]: + # distributed_executor_backend must be set in VllmConfig.__post_init__ distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. @@ -443,30 +444,29 @@ def _get_executor_cls(cls, "distributed_executor_backend must be a subclass of " f"ExecutorBase. Got {distributed_executor_backend}.") executor_class = distributed_executor_backend - elif engine_config.parallel_config.world_size > 1: - if distributed_executor_backend == "ray": - from vllm.executor.ray_distributed_executor import ( - RayDistributedExecutor) - executor_class = RayDistributedExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.mp_distributed_executor import ( - MultiprocessingDistributedExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingDistributedExecutor - elif distributed_executor_backend == "uni": - # JAX-style, single-process, multi-device executor. - from vllm.executor.uniproc_executor import UniProcExecutor - executor_class = UniProcExecutor - elif distributed_executor_backend == "external_launcher": - # executor with external launcher - from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher) - executor_class = ExecutorWithExternalLauncher - else: + elif distributed_executor_backend == "ray": + from vllm.executor.ray_distributed_executor import ( + RayDistributedExecutor) + executor_class = RayDistributedExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.mp_distributed_executor import ( + MultiprocessingDistributedExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingDistributedExecutor + elif distributed_executor_backend == "uni": + # JAX-style, single-process, multi-device executor. 
from vllm.executor.uniproc_executor import UniProcExecutor executor_class = UniProcExecutor + elif distributed_executor_backend == "external_launcher": + # executor with external launcher + from vllm.executor.uniproc_executor import ( # noqa + ExecutorWithExternalLauncher) + executor_class = ExecutorWithExternalLauncher + else: + raise ValueError("unrecognized distributed_executor_backend: " + f"{distributed_executor_backend}") return executor_class @classmethod diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py index 5c19888d45401..33bba04882be6 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py @@ -67,6 +67,8 @@ def extract_reasoning_content_streaming( ]): return None + # Check if is present in previous or delta. + # Keep compatibility with models that don't generate tokens. if self.think_start_token_id in previous_token_ids: if self.think_end_token_id in delta_token_ids: # in previous, in delta, @@ -85,7 +87,6 @@ def extract_reasoning_content_streaming( # reasoning content continues return DeltaMessage(reasoning_content=delta_text) elif self.think_start_token_id in delta_token_ids: - logger.info(delta_text) if self.think_end_token_id in delta_token_ids: # in delta, in delta, extract reasoning content start_index = delta_text.find(self.think_start_token) @@ -101,35 +102,46 @@ def extract_reasoning_content_streaming( # reasoning content continues return DeltaMessage(reasoning_content=delta_text) else: - # No in previous or delta, reasoning content continues. - return DeltaMessage(content=delta_text) + # No in previous or delta, also need to check for . + # Because the model may have generated without + # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f + if self.think_end_token_id in delta_token_ids: + # in delta with more tokens, + # extract reasoning content and content + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + elif self.think_end_token_id in previous_token_ids: + # in previous, thinking content ends + return DeltaMessage(content=delta_text) + else: + # no in previous or delta, reasoning content continues + return DeltaMessage(reasoning_content=delta_text) def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> Tuple[Optional[str], Optional[str]]: - # Check if the model output contains the tokens. - if (self.think_start_token not in model_output - or self.think_end_token not in model_output): + # DeepSeek R1 doesn't generate now. + # Thus we assume the reasoning content is always at the start. + # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f + if self.think_end_token not in model_output: return None, model_output else: + # Add a start token if it's missing to keep compatibility. 
+ if self.think_start_token not in model_output: + model_output = f"{self.think_start_token}{model_output}" # Use a regex to find the reasoning content reasoning_content = self.reasoning_regex.findall(model_output)[0] - # Remove the reasoning content from the model output - # Although deepseek's token is always at the - # beginning of the line, we cannot guarantee that the - # other models will follow this convention. - # Therefore, we need to add :start_index. - start_index = model_output.find(self.think_start_token) - if start_index != -1: - end_index = start_index + len( - f"{self.think_start_token}{reasoning_content}{self.think_end_token}" - ) - model_output = model_output[:start_index] + \ - model_output[end_index:] - - if len(model_output) == 0: - return reasoning_content, None - - return reasoning_content, model_output + end_index = len( + f"{self.think_start_token}{reasoning_content}{self.think_end_token}" + ) + final_output = model_output[end_index:] + + if len(final_output) == 0: + return reasoning_content, None + + return reasoning_content, final_output diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index fb76276bb4b34..242690f8e1b8f 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -8,11 +8,11 @@ import torch.nn as nn from typing_extensions import TypeVar +import vllm.platforms from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async @@ -108,8 +108,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: """ # NOTE: This is logged in the executor because there can be >1 workers. 
logger.info("# %s blocks: %d, # CPU blocks: %d", - current_platform.dispatch_key, num_gpu_blocks, - num_cpu_blocks) + vllm.platforms.current_platform.dispatch_key, + num_gpu_blocks, num_cpu_blocks) max_concurrency = (num_gpu_blocks * self.cache_config.block_size / self.model_config.max_model_len) logger.info("Maximum concurrency for %s tokens per request: %.2fx", diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7b30155971a6d..33c0a25803ca6 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -7,10 +7,10 @@ import msgspec +import vllm.platforms from vllm.config import ParallelConfig from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase @@ -54,10 +54,10 @@ def get_node_ip(self) -> str: def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: node_id = ray.get_runtime_context().get_node_id() - device_key = current_platform.ray_device_key + device_key = vllm.platforms.current_platform.ray_device_key if not device_key: raise RuntimeError("current platform %s does not support ray.", - current_platform.device_name) + vllm.platforms.current_platform.device_name) gpu_ids = ray.get_runtime_context().get_accelerator_ids( )[device_key] return node_id, gpu_ids diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index dcb4a8f27c252..e5464cafaecbf 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -101,7 +101,7 @@ def _init_executor(self) -> None: # - MASTER_PORT distributed_init_method = "env://" rank = int(os.environ["RANK"]) - local_rank = rank + local_rank = int(os.environ["LOCAL_RANK"]) is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py index 51e1bfab3f513..3661a7214648a 100644 --- a/vllm/lora/punica_wrapper/punica_hpu.py +++ b/vllm/lora/punica_wrapper/punica_hpu.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union, final +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final import torch from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, dispatch_bgmv_linear) from .punica_base import PunicaWrapperBase +from .utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext @final @@ -19,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, max_batches, device) + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size, + extra_vocab_size, self.device, None) + # Updating each element in `long_lora_offsets` with `lora_offset` slows + # down perf in HPU due to a series of `strided_insert` ops during lazy + # graph accumulation. 
Hence HPU appends `lora_offset` to a list and + # converts it to a tensor only after it is ready. + if long_lora_context: + index_mapping_indices: List[int] = list( + mapping.index_mapping).copy() + long_lora_offsets: List[int] = [] + for i in range(len(index_mapping_indices)): + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets.append(lora_offset) + long_lora_offsets_tensor = torch.tensor(long_lora_offsets, + device=self.device, + dtype=torch.long) + indices_len[-1] = long_lora_offsets_tensor.shape[-1] + + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py new file mode 100644 index 0000000000000..5fd1264910231 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -0,0 +1,534 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionMetadata) +from vllm.attention.backends.xformers import XFormersMetadata +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + +# Added by the IBM Team, 2024 + + +# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +@CustomOp.register("mixer2_gated_rms_norm") +class Mixer2RMSNormGated(CustomOp): + + def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + 
set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + assert self.full_hidden_size % self.tp_size== 0,\ + "Tensor parallel world size must divide hidden size." + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. + input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = (global_sums / count) + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + else: + redundant_tp: bool = self.n_groups % self.tp_size != 0 + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] + + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + if self.tp_size > 1 or self.n_groups != 1: + return self.forward_native(x, gate) + + from vllm import _custom_ops as ops + + # cast x and gate to float32 before silu + out = torch.empty_like(x) + y = x * nn.functional.silu(gate.to(torch.float32)) + ops.rms_norm( + out, + y.to(x.dtype), + self.weight.data, + self.variance_epsilon, + ) + return out + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + return tp_size - ngroups % tp_size + + +def mamba_v2_sharded_weight_loader( + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + + # - iterate over the shard specs + for full_dim, extra, ratio in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. 
+ # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard + shard_size = full_dim // tp_size + + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. + rank = tp_rank // ratio + + # - leftmost boundary index into loaded weight. + loaded_skip = rank * shard_size + loaded_start_idx = loaded_boundary + loaded_skip + + # - take these many dims from the loaded weight. + take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[ + boundary:(boundary + take), # type: ignore[misc] + ...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] + loaded_start_idx + take)] # type: ignore[misc] + + # move indexing boundaries + boundary += shard_size + loaded_boundary += (full_dim - extra) + + return loader + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer2") +class MambaMixer2(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation="silu", + chunk_size: int = 256, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - IF, world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + assert num_heads % self.tp_size == 0, \ + "Tensor parallel world size must divide num heads." 
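+        # For example, with the default num_heads=128, n_groups=1 and
+        # tp_size=2, tp_size does not divide n_groups, so one extra group is
+        # allocated (extra_groups_for_head_shards(1, 2) == 1) and
+        # self.n_groups becomes 2: each rank then holds 64 heads together
+        # with a full replica of the single group.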
+ + self.ssm_state_size = ssm_state_size + self.activation = activation + + self.chunk_size = chunk_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + + self.n_groups = n_groups + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) + + self.conv_dim = (intermediate_size + + 2 * self.n_groups * ssm_state_size) + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size + + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) + + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups to the head shards + group_shard_settings = ( + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + self.num_heads // + n_groups, # ratio for mapping back to original group + ) + intermediate_settings = (intermediate_size, 0, 1) + head_setings = (self.num_heads, 0, 1) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], self.tp_size, tp_rank) + }) + + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) + + # - these are TPed by heads to reduce the size of the + # temporal shape + self.A = nn.Parameter( + torch.empty( + divide(num_heads, self.tp_size), + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.out_proj = RowParallelLinear(intermediate_size, + hidden_size, + 
bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) + + self.norm = Mixer2RMSNormGated(intermediate_size, + n_groups, + eps=rms_norm_eps) + + def forward_native(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + ): + + seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + # detect if there are prefills + has_prefill = attn_metadata.num_prefills > 0 + + # - also need flags to indicate if there are initial states + # - currently we really only support the FlashAttention backend + has_initial_states = None + if (isinstance(attn_metadata, + (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata)) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = attn_metadata.context_lens_tensor > 0 + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], + dim=-1, + ) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if has_prefill: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_B_C = causal_conv1d_fn( + hidden_states_B_C.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=has_initial_states, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc).transpose( + 0, 1)[:seq_len] + + # TODO: Why is this needed? + hidden_states_B_C = hidden_states_B_C.contiguous() + else: + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + + # - get hidden_states, B and C after depthwise convolution. + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) + + # 3. 
State Space Model sequence transformation + if has_prefill: + + initial_states = None + if has_initial_states is not None and any(has_initial_states): + for idx in mamba_cache_params.state_indices_tensor[ + ~has_initial_states]: + mamba_cache_params.ssm_state[idx].zero_() + initial_states = mamba_cache_params.ssm_state[ + mamba_cache_params.state_indices_tensor] + + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, + self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, seq_len, self.n_groups // self.tp_size, -1), + C.view(1, seq_len, self.n_groups // self.tp_size, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=sequence_idx, + cu_seqlens=attn_metadata.query_start_loc, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + for i, idx in enumerate(mamba_cache_params.state_indices_tensor): + mamba_cache_params.ssm_state[idx].copy_(varlen_state[i]) + + # - reshape + hidden_states = scan_output.view(seq_len, -1) + else: + + n_groups = self.n_groups // self.tp_size + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(-1, n_groups, B.shape[1] // n_groups) + C = C.view(-1, n_groups, C.shape[1] // n_groups) + hidden_states_reshaped = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into number of current batches + # - in this case there is no more prefill, so the batches gen + # 1 token at a time + # - thus hidden will be (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using "mamba_cache_params.state_indices_tensor", just as + # above in the prefill case + + hidden_states = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor, + ) + hidden_states = hidden_states.view( + -1, (self.num_heads // self.tp_size) * self.head_dim) + + # # 4. gated MLP + hidden_states = self.norm(hidden_states, gate) + + # # 5. Final linear projection + out, _ = self.out_proj(hidden_states) + return out diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 3c35f1ac0dcf5..b31b980fbe84a 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2024, Tri Dao, Albert Gu. 
-# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py import torch import triton diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py new file mode 100644 index 0000000000000..388a63327213b --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = 
b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + + +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. 
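+    # - shape is (batch, nchunks, chunk_size, chunk_size), with an extra
+    #   ngroups dim after nchunks when a and b carry a group dimension
+    #   (see the docstring above).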
+ out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 0000000000000..722fbd714ca8f --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,615 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, 
+ chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=c_idx >= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care 
of + if ((c_off == 0) and + (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) + + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1 # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. + + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, + mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1, + other=0.0).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. 
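+                # - i.e. the decay is taken relative to the offset boundary
+                #   (dA_cs_m_boundary) rather than the start of the physical
+                #   chunk, so the carried-over state is weighted correctly
+                #   for the (logical) chunk beginning at c_off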
+ scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. 
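+        # - roughly, this accumulates the within-chunk (diagonal block) term
+        #   sum_k (C_m . B_k) * exp(dA_cs_m - dA_cs_k) * dt_k * x_k, with the
+        #   causal mask applied below when IS_CAUSAL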
+ cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + +def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): + + # convert seq_idx to chunk indices and offsets + # - derive the cu_seqlens + _, cu_seqlens = torch.where(seq_idx.diff()) + cu_seqlens += 1 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device) + chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device) + + cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]] + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + _s, _e = s // chunk_size + p, e // chunk_size + p + 1 + + # adjust inidces and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + initial_states=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, 
ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + chunk_indices, chunk_offsets = None, None + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert initial_states.shape == (seq_idx[0].max() + 1, nheads, + headdim, dstate) + + if initial_states.shape[0] == 1: + # no in this case no point to use initial states + initial_states = None + else: + chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( + seq_idx, chunk_size) + + # Allocates output. + out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks + if chunk_offsets is None else len(chunk_offsets), nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py new file mode 100644 index 0000000000000..a970ac94580b4 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,750 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py + +# ruff: noqa: E501 + +import math + +import torch +import triton +import triton.language as tl + +from .mamba_ssm import softplus + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 
'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < 
dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = 
tl.program_id(axis=2)
+    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
+    pid_m = tl.program_id(axis=0) // num_pid_n
+    pid_n = tl.program_id(axis=0) % num_pid_n
+    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
+    pid_c = (end_idx - 1) // chunk_size
+    b_ptr += pid_c * chunk_size * stride_b_seqlen + (
+        pid_h // nheads_ngroups_ratio) * stride_b_head
+    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
+    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
+    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
+    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
+
+    if HAS_INITSTATES:
+        # if there are init states provided, we differentiate between states (which
+        # are boundary conditions at a chunk boundary) and initstates (which are boundary
+        # conditions when a new example in a continuous batch starts)
+        initstates_ptr += pid_h * stride_init_states_head
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
+                      offs_k[None, :] * stride_x_seqlen)
+    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
+                      offs_k[:, None] * stride_b_seqlen)
+    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
+    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) *
+                         stride_dA_cs_csize).to(tl.float32)
+    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
+
+    chunk_size_limit = end_idx - pid_c * chunk_size
+    start_idx = tl.load(cu_seqlens_ptr + pid_b)
+    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
+
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
+        x = tl.load(x_ptrs,
+                    mask=(offs_m[:, None] < hdim) &
+                    (offs_k[None, :] < chunk_size_limit - k) &
+                    (offs_k[None, :] >= start_idx_cur - k),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=(offs_k[:, None] < chunk_size_limit - k) &
+                    (offs_n[None, :] < dstate) &
+                    (offs_k[:, None] >= start_idx_cur - k),
+                    other=0.0).to(tl.float32)
+        dA_cs_k = tl.load(dA_cumsum_ptrs,
+                          mask=offs_k < chunk_size_limit - k,
+                          other=0.0).to(tl.float32)
+        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
+                       other=0.0).to(tl.float32)
+        scale = tl.where(
+            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
+            tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
+        b *= scale[:, None]
+        b = b.to(x_ptr.dtype.element_ty)
+        acc += tl.dot(x, b)
+        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
+        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
+        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
+        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
+
+    # If the sequence starts after the last chunk boundary (and there are no
+    # init states), we don't need to add any contribution from a previous chunk.
+    # If HAS_INITSTATES==True we need to consider two possibilities:
+    # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
+    # - if start_idx >= pid_c * chunk_size, then we need to insert initstates
+    if ((start_idx < pid_c * chunk_size)  # first chunk
+            or (HAS_INITSTATES)):
+
+        dA_cs_boundary = 0.0  # default
+
+        if not HAS_INITSTATES:
+            past_states_ptrs = chunk_states_ptr + (
+                offs_m[:, None] * stride_chunk_states_hdim +
+                offs_n[None, :] * stride_chunk_states_dstate)
+        else:
+
+            # - this seems repetitive, but it's to help the compiler
+            if start_idx < pid_c * chunk_size:
+                past_states_ptrs = chunk_states_ptr + (
+                    offs_m[:, None] * stride_chunk_states_hdim +
+                    offs_n[None, :] * stride_chunk_states_dstate)
+
else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load(dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - + 1) * stride_dA_cs_csize).to( + tl.float32) + + past_states = tl.load(past_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads, ) + if dt_bias is not None: + assert dt_bias.shape == (nheads, ) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + 
states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen(B, + x, + dt, + dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. + cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + HAS_INITSTATES=initial_states is not None) + return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py new file mode 100644 index 0000000000000..97cdb70b63cc6 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py + +# ruff: noqa: E501 + +import torch +import triton +from einops import rearrange +from packaging import version + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, + chunk_state_varlen) +from .ssd_state_passing import _state_passing_fwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def _mamba_chunk_scan_combined_fwd(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads, ) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride( + 1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride( + 1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, + headdim, dstate) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=seq_idx, + states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx and iii) has_cu_seqlens to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - this will ensure that states will be updated with the rightmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. 
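+    # - roughly, the recurrence implemented by _state_passing_fwd is
+    #       state_in[c + 1] = exp(dA_cumsum[c, -1]) * state_in[c] + chunk_state[c]
+    #   with state_in[0] taken from initial_states (or zeros), and state_in[c]
+    #   swapped for the next example's init state (or zeroed) whenever the
+    #   rightmost seq_idx of a chunk changes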
+ states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None else None, + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=C.dtype, + is_cont_batched=cu_seqlens is not None) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, + B, + chunk_size, + seq_idx=seq_idx, + output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. + out, out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + initial_states=initial_states, + ) + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def mamba_chunk_scan_combined(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, + varlen_states) if not return_final_states else (out, + final_states, + varlen_states) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py new file mode 100644 index 0000000000000..d8f87c113f168 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -0,0 +1,207 @@ 
+# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py + +# ruff: noqa: E501 + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + # Matrix dimensions + dim, + nchunks, + seqlen, + chunk_size, + # Strides + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_initstates_head + if not IS_CONT_BATCHED: + initstates_ptr += pid_b * stride_initstates_batch + + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + # - states will be the past state of the sequence that continues on the current check + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_new below. + seq_idx_new = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) + if HAS_INITSTATES: + if IS_CONT_BATCHED and seq_idx != seq_idx_new: + # this means in the current chunk the rightmost flushed seq + # has changed. 
+ # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) + else: + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_chunk_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + is_cont_batched=False, +): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + if is_cont_batched: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert seq_idx is not None, "" + assert initial_states.shape == (seq_idx.max().item() + 1, nheads, + dim) + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. + assert initial_states.shape == (batch, nheads, dim) + + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), + device=states.device, + dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), + device=states.device, + dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, + out, + final_states, + dA_chunk_cumsum, + initial_states, + seq_idx, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size if seq_idx is not None else 0, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_chunk_cumsum.stride(0), + dA_chunk_cumsum.stride(2), + dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2)) if initial_states is not None else + (0, 0, 0)), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_CONT_BATCHED=is_cont_batched, + ) + return out, final_states diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 6ded3874fc1dd..6cd508d057a44 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -11,6 +11,7 @@ "deepspeedfp", "tpu_int8", "fp8", + "ptpc_fp8", "fbgemm_fp8", "modelopt", # The order of gptq methods is important for config.py iteration over @@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .modelopt import ModelOptFp8Config from .moe_wna16 import MoeWNA16Config from .neuron_quant import NeuronQuantConfig + from .ptpc_fp8 import PTPCFp8Config from .qqq import 
QQQConfig from .tpu_int8 import Int8TpuConfig @@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "ptpc_fp8": PTPCFp8Config, "qqq": QQQConfig, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py new file mode 100644 index 0000000000000..1ded5389e5f45 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + Fp8KVCacheMethod, + Fp8LinearMethod) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear) +from vllm.platforms import current_platform + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class PTPCFp8Config(Fp8Config): + """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8.""" + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + ) -> None: + if not current_platform.is_rocm(): + raise ValueError( + "ptpc_fp8 quantization is supported only on ROCm.") + + if not current_platform.has_device_capability(94): + raise ValueError( + "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501 + ) + if activation_scheme == "static": + raise ValueError( + "ptpc_fp8 as of now only support dynamic quantization.") + + super().__init__(is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + @classmethod + def get_name(cls) -> str: + return "ptpc_fp8" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls(activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return PTPCFp8LinearMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + +class PTPCFp8LinearMethod(Fp8LinearMethod): + """Linear method for Per-Token and Per-Channel FP8 Quantization. + Only supports loading quantized BF16 model checkpoints with dynamic + activation scaling. To load FP16 model checkpoints, user must specify + to convert the FP16 model weight loading into BF16. + The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. 
Only support float8_e4m3fnuz data type due to the limitation of + torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: PTPCFp8Config): + super().__init__(quant_config=quant_config) + # Force weight quantization + self.quant_config.is_checkpoint_fp8_serialized = False + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + + assert layer.weight.data.dtype == torch.bfloat16, \ + f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + # Quantize the weights. + qweight, weight_scale = ops.scaled_fp8_quant( + layer.weight, scale=None, use_per_token_if_dynamic=True) + + # Update the layer with the new values. + layer.weight = Parameter( + qweight.t(), requires_grad=False) # Pretranspose the weight + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=None, + bias=bias, + cutlass_fp8_supported=False, + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git 
a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index dedeb0c296bd4..bea6390f71ff7 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -11,6 +11,13 @@ # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +# The condition to determine if it is on a platform that supports +# torch._scaled_mm rowwise feature. +# The condition is determined once as the operations +# are time consuming. 
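+# In practice this currently means ROCm with device capability 9.4 or newer,
+# i.e. AMD Instinct MI300-series GPUs (see ptpc_fp8.py).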
+USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and current_platform.has_device_capability(94)) + def sparse_cutlass_supported() -> bool: if not current_platform.is_cuda(): @@ -172,6 +179,26 @@ def apply_fp8_linear( return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + elif (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + # For now validated on ROCm platform + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # For CUDA platform please validate if the + # torch._scaled_mm support rowwise scaled GEMM + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + else: # Fallback for channelwise case, where we use unfused DQ # due to limitations with scaled_mm diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b3b9b0e876057..5d7f9396c20b0 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -206,9 +206,10 @@ def forward_hpu( ) -> Tuple[torch.Tensor, torch.Tensor]: from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb) - positions = positions.flatten() if offsets is not None: + offsets = offsets.view(positions.shape[0], -1) positions = positions + offsets + positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions).view( num_tokens, 1, -1) @@ -509,15 +510,12 @@ def __init__( ): super().__init__() - if rotary_dim != head_size: - raise ValueError( - f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \ - rotary_dim != head_size ({rotary_dim}!={head_size}).") if is_neox_style is False: raise ValueError( "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
) + self.rotary_dim = rotary_dim self.head_size = head_size self.max_position_embeddings = max_position_embeddings self.original_max_position_embeddings = original_max_position_embeddings @@ -557,7 +555,7 @@ def __init__( def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( - 0, self.head_size, 2, dtype=torch.float) / self.head_size))) + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) return inv_freq def _compute_cos_sin_cache( @@ -596,8 +594,15 @@ def forward( cos = cos.repeat(1, 2).unsqueeze(-2) sin = sin.repeat(1, 2).unsqueeze(-2) - query = query * cos + _rotate_neox(query) * sin - key = key * cos + _rotate_neox(key) * sin + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = query_rot * cos + _rotate_neox(query_rot) * sin + query = torch.cat((query_rot, query_pass), dim=-1) + + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = key_rot * cos + _rotate_neox(key_rot) * sin + key = torch.cat((key_rot, key_pass), dim=-1) return query.flatten(-2), key.flatten(-2) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index cade0a1dd5950..8b2c5610f1f91 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,6 +6,7 @@ import json import os import tempfile +import time from collections import defaultdict from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union @@ -14,7 +15,8 @@ import huggingface_hub.constants import numpy as np import torch -from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from huggingface_hub import (HfFileSystem, hf_hub_download, scan_cache_dir, + snapshot_download) from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm @@ -253,6 +255,8 @@ def download_weights_from_hf( # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. 
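+    # Additionally, record the HF cache size and wall-clock time around the
+    # download below, so that timing is only logged when new files were
+    # actually fetched (i.e. the cache grew).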
with get_lock(model_name_or_path, cache_dir): + start_size = scan_cache_dir().size_on_disk + start_time = time.perf_counter() hf_folder = snapshot_download( model_name_or_path, allow_patterns=allow_patterns, @@ -262,6 +266,11 @@ def download_weights_from_hf( revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, ) + end_time = time.perf_counter() + end_size = scan_cache_dir().size_on_disk + if end_size != start_size: + logger.info("Time took to download weights for %s: %.6f seconds", + model_name_or_path, end_time - start_time) return hf_folder @@ -453,7 +462,6 @@ def pt_weights_iterator( state = torch.load(bin_file, map_location="cpu", weights_only=True) yield from state.items() del state - torch.cuda.empty_cache() def get_gguf_extra_tensor_names( diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py new file mode 100644 index 0000000000000..72b74e31b6cc8 --- /dev/null +++ b/vllm/model_executor/models/bamba.py @@ -0,0 +1,592 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only Bamba model.""" +# Added by the IBM Team, 2024 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import BambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class BambaMLP(nn.Module): + + def __init__( + self, + config: BambaConfig, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BambaMixerDecoderLayer(nn.Module): + + def __init__(self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.mamba = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + chunk_size=config.mamba_chunk_size, + quant_config=quant_config) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params, sequence_idx) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class BambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": BambaAttentionDecoderLayer, + "mamba": BambaMixerDecoderLayer +} + + +class BambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + 
org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + residual = None + num_attn = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = None + if isinstance(layer, BambaAttentionDecoderLayer): + kv_cache = kv_caches[num_attn] + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance(layer, BambaMixerDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + sequence_idx=seq_idx, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Bamba 
currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = BambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # follow jamba + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: + # for compilation + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + elif self.scheduler_config is not None: + # for eager just take the scheduler_config if avail + self.max_batch_size = self.scheduler_config.max_num_seqs + else: + self.max_batch_size = 8192 + 2 + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = self.config.mamba_expand * hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + 
divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index a316486752590..153c85cfb2141 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -4,21 +4,21 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from array import array -from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple, - TypedDict) +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import torch -from PIL import Image from torch import nn from torch.nn import LayerNorm +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import PreTrainedTokenizer, TensorType +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -35,73 +35,20 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs, - NestedTensors) -from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, BatchFeature, + MultiModalFieldConfig, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - -logger = init_logger(__name__) - - -def calculate_image_placeholder(vision_config): - return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2 - - -def mm_input_mapper_for_glmv( - ctx: InputContext, - data: ModalityData[object], -) -> Dict: - model_config = ctx.model_config - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - if tokenizer is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") - try: - raw_batch_data = tokenizer.apply_chat_template( - conversation=[{ - "role": "user", - "image": data - }], - add_generation_prompt=True, - tokenize=True, - 
return_tensors="pt", - return_dict=True).data - except Exception: - logger.error("Failed to process image (%s)", data) - raise - pixel_values = raw_batch_data['images'] - - return MultiModalKwargs({'pixel_values': pixel_values}) - - -def merge_glm_vision_embeddings( - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - vision_embeddings: torch.Tensor, - boi_token_id: int, - eoi_token_id: int, -) -> torch.Tensor: - - boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0] - eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0] - - mask = torch.zeros_like(input_ids, dtype=torch.bool) - - for boi_pos, eoi_pos in zip(boi_positions, eoi_positions): - assert boi_pos < eoi_pos - mask[boi_pos:eoi_pos + 1] = True - inputs_embeds[mask] = vision_embeddings.view(-1, - vision_embeddings.shape[-1]) - return inputs_embeds + maybe_prefix, merge_multimodal_embeddings) class GLMImagePixelInputs(TypedDict): @@ -109,120 +56,179 @@ class GLMImagePixelInputs(TypedDict): """Shape: `(batch_size, num_channels, height, width)`""" -def get_max_glmv_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config(ChatGLMConfig) +class GLM4VProcessor: + """ + This model doesn't define its own HF processor, + so we implement our own one here. - vision_config = getattr(hf_config, 'vision_config', None) - if vision_config is None: - return 1 - elif isinstance(vision_config, dict): - return calculate_image_placeholder(vision_config) + """ - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def __init__( + self, + config: ChatGLMConfig, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + self.config = config + self.tokenizer = tokenizer + + if vision_config := getattr(config, "vision_config", None): + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) + else: + self.image_transform = None -def dummy_data_for_glmv(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]) -> DummyData: - hf_config = ctx.get_hf_config(ChatGLMConfig) - vision_config = getattr(hf_config, 'vision_config', None) + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + text_inputs = self.tokenizer(text) + if len(images) == 0: + image_inputs = {} + else: + if self.image_transform is None: + raise ValueError("This model does not support image inputs") + + pixel_values = [self.image_transform(image) for image in images] + image_inputs = {"pixel_values": torch.stack(pixel_values)} + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + tensor_type=return_tensors, + ) - if vision_config is None: - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) - seq_data = SequenceData(token_ids) - return DummyData(seq_data, None) - elif isinstance(vision_config, dict): - image_size = vision_config["image_size"] - image_placeholder_length = calculate_image_placeholder(vision_config) - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, 
[hf_config.boi_token_id] + - [0] * image_placeholder_length + - [hf_config.eoi_token_id]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0] * (seq_len - image_placeholder_length - 2)) - seq_data = SequenceData(token_ids) - mm_data = { - "image": Image.new("RGB", (image_size, image_size), color=0) - } +class GLM4VProcessingInfo(BaseProcessingInfo): - return DummyData(seq_data, mm_data) + def get_tokenizer(self): + tokenizer = self.ctx.tokenizer + assert isinstance(tokenizer, PreTrainedTokenizer) + return tokenizer - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def get_hf_config(self): + return self.ctx.get_hf_config(ChatGLMConfig) + def get_hf_processor(self) -> GLM4VProcessor: + return GLM4VProcessor( + self.get_hf_config(), + self.get_tokenizer(), + ) -def find_all_positions(input_ids: List[int], target: int) -> List[int]: - return [index for index, value in enumerate(input_ids) if value == target] + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_feature_tokens()} -def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + if not (vision_config := getattr(hf_config, "vision_config", None)): + return 0 - hf_config = ctx.get_hf_config(ChatGLMConfig) - vision_config = getattr(hf_config, 'vision_config', None) + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length - if vision_config is None: - return inputs - elif isinstance(vision_config, dict): - image_placeholder_length = calculate_image_placeholder(vision_config) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def get_num_image_feature_tokens(self) -> int: + # EVA2CLIPModel has embeddings for boi and eoi tokens as well + return self.get_num_image_tokens() + 2 - input_ids = inputs["prompt_token_ids"] - tokenizer = cached_get_tokenizer( - ctx.model_config.model, - trust_remote_code=ctx.model_config.trust_remote_code) +class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): - try: - raw_batch_data = tokenizer.apply_chat_template( - conversation=[{ - "role": "user", - "image": multi_modal_data["image"], - "content": inputs['prompt'], - }], - add_generation_prompt=True, - tokenize=True, - return_tensors="pt", - return_dict=True, - ).data - except Exception: - logger.error("Failed to process content (%s)", inputs['prompt']) - raise - input_ids = raw_batch_data['input_ids'][0].tolist() + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + if not (vision_config := getattr(hf_config, "vision_config", None)): + return ProcessorInputs(prompt_text="", mm_data={}) - boi_token_id = hf_config.boi_token_id - eoi_token_id = hf_config.eoi_token_id - boi_positions = find_all_positions(input_ids, boi_token_id) - eoi_positions = find_all_positions(input_ids, eoi_token_id) + target_width = target_height = vision_config["image_size"] + num_images = mm_counts.get("image", 0) - assert len(boi_positions) == 
len(eoi_positions) + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } - new_input_ids = [] - final_processed_position = 0 + base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" - for boi_position, eoi_position in zip(boi_positions, eoi_positions): - assert boi_position < eoi_position - new_input_ids.extend(input_ids[final_processed_position:boi_position + - 1]) - new_input_ids.extend([input_ids[boi_position + 1]] * - image_placeholder_length) - final_processed_position = eoi_position + return ProcessorInputs( + prompt_text=base_text * num_images, + mm_data=mm_data, + ) - new_input_ids.extend(input_ids[final_processed_position:]) - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(new_input_ids) +class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): - return token_inputs( - prompt_token_ids=new_input_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - ) + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + if not hasattr(hf_config, "vision_config"): + return [] + + boi_token_id = hf_config.boi_token_id + image_token_id = hf_config.pad_token_id + eoi_token_id = hf_config.eoi_token_id + + def get_replacement(item_idx: int): + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [boi_token_id] + image_tokens + [eoi_token_id] + + return [ + PromptReplacement( + modality="image", + target=[boi_token_id, image_token_id, eoi_token_id], + replacement=get_replacement, + ), + ] class GLMAttention(nn.Module): @@ -572,12 +578,16 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.embedding(input_ids) if multimodal_embeddings is not None: - inputs_embeds = merge_glm_vision_embeddings( + inputs_embeds = merge_multimodal_embeddings( input_ids=input_ids, inputs_embeds=inputs_embeds, - vision_embeddings=multimodal_embeddings, - boi_token_id=self.config.boi_token_id, - eoi_token_id=self.config.eoi_token_id) + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=[ + self.config.boi_token_id, + self.config.pad_token_id, + self.config.eoi_token_id, + ], + ) return inputs_embeds def forward( @@ -593,14 +603,12 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. - if intermediate_tensors is None and inputs_embeds is None: + if intermediate_tensors is not None: + inputs_embeds = intermediate_tensors["hidden_states"] + elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) - input_ids = None - else: - inputs_embeds = intermediate_tensors["hidden_states"] - # Run encoder. 
hidden_states = self.encoder( hidden_states=inputs_embeds, @@ -763,11 +771,21 @@ def get_mm_mapping(self) -> MultiModelKeys: connector="transformer.vision.linear_proj", tower_model="transformer.vision.transformer") + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + return self.transformer.get_multimodal_embeddings(**kwargs) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids, + multimodal_embeddings) + -@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) +@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, + info=GLM4VProcessingInfo, + dummy_inputs=GLM4VDummyInputsBuilder) class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsMultiModal): # Ensure that the LoRA support check passes when the class is not diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0c6f07ce7b112..fd0e58fa1458d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -581,7 +581,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - ) + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") else: self.embed_tokens = PPMissingLayer() diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index d82c0815213bc..f307f279dad4d 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -455,14 +455,9 @@ def forward(self, self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d91c8782a121c..2ff52dd789125 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -128,6 +128,9 @@ def __init__(self, # MistralConfig has an optional head_dim introduced by Mistral-Nemo self.head_dim = getattr(config, "head_dim", self.hidden_size // self.total_num_heads) + # Phi models introduced a partial_rotary_factor parameter in the config + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.rotary_dim = int(partial_rotary_factor * self.head_dim) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -159,7 +162,7 @@ def __init__(self, self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, + rotary_dim=self.rotary_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, @@ -464,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): 
mistral_mapping = { "layers": "model.layers", "attention": "self_attn", + "qscale_act": "input_scale", + "qscale_weight": "weight_scale", + "kv_fake_quantizer.qscale_act": "kv_scale", "wq": "q_proj", "wk": "k_proj", "wv": "v_proj", @@ -587,15 +593,24 @@ def permute(w: torch.Tensor, n_heads: int): modules = name.split(".") # rotary embeds should be sliced - if "wk" in modules: + if "wk" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, self.config.num_key_value_heads) - elif "wq" in modules: + elif "wq" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, self.config.num_attention_heads) - for item in modules: - if item in mapping and mapping[item] not in name: + num_modules = len(modules) + for i in range(num_modules): + item = modules[i] + next_item = modules[i + 1] if i < num_modules - 1 else None + + combined_item = (f"{item}.{next_item}" + if next_item is not None else None) + + if combined_item in mapping: + name = name.replace(combined_item, mapping[combined_item]) + elif item in mapping and mapping[item] not in name: name = name.replace(item, mapping[item]) return name, loaded_weight diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 5034b334564e8..3bbc219e92a65 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -232,15 +232,7 @@ def forward(self, self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, attn_metadata, mamba_cache_params, intermediate_tensors, diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 353177f784b2e..ce4197507da03 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -5,7 +5,6 @@ import torch -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.utils import PAD_SLOT_ID @@ -42,8 +41,7 @@ def __init__(self, dtype, num_mamba_layers, max_batch_size, self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.free_cache_indices = list(range(max_batch_size)) - def current_run_tensors(self, input_ids: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs): + def current_run_tensors(self, **kwargs) -> MambaCacheParams: """ Return the tensors for the current run's conv and ssm state. 
""" @@ -66,7 +64,8 @@ def current_run_tensors(self, input_ids: torch.Tensor, (mamba_cache_tensors, state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - return (mamba_cache_tensors, state_indices_tensor) + return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], + state_indices_tensor) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 003e9c84c1c0a..e78e8d62cc47c 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.instruct.mm_encoder - max_image_size = mm_encoder.mm_config.max_image_size - image_patch_size = mm_encoder.mm_config.image_patch_size + image_config = mm_encoder.mm_config if hasattr( + mm_encoder, "mm_config") else mm_encoder.image_config + + max_image_size = image_config.max_image_size + image_patch_size = image_config.image_patch_size return ((max_image_size // image_patch_size)**2) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 8970661243148..4b8aeaddbdd37 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -63,18 +63,6 @@ logger = init_logger(__name__) -# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad; -# for the time being, these tags are not considered as special at encoding -# time. This may change as VLLMs multimodal API changes in the future. -IMG_START = "" -IMG_END = "" -IMG_PAD = "" -# Image context is fixed at 256 for all images -MAX_QWEN_IMG_TOKENS = 256 -# Image normalization params -CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) -CLIP_STD = (0.26862954, 0.26130258, 0.27577711) - class QwenImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -622,25 +610,6 @@ def forward( return hidden_states -def build_normalization_transform(image_size: int) -> transforms.Compose: - """ - Build a normalization transform which can be applied to one or - more input images from which we want to extract visual features. - - Args: - image_size: size of the image to be processed for visual embeddings. - - Returns: - Callable transform for normalizing and resizing one RGB image. 
- """ - return transforms.Compose([ - transforms.Resize((image_size, image_size), - interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), - ]) - - @lru_cache(maxsize=1) def _get_tokenizer_without_image_pad( tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: @@ -716,16 +685,34 @@ def __init__( self.config = config self.tokenizer = tokenizer - if hasattr(self.config, "visual"): - self.image_transform = build_normalization_transform( - config.visual["image_size"]) + if vision_config := getattr(self.config, "visual", None): + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) else: self.image_transform = None - special_tokens: dict[str, - int] = tokenizer.special_tokens # type: ignore - self.img_start_id = special_tokens[IMG_START] - self.img_end_id = special_tokens[IMG_END] + @property + def image_start_tag(self) -> str: + return self.tokenizer.image_start_tag # type: ignore + + @property + def image_end_tag(self) -> str: + return self.tokenizer.image_end_tag # type: ignore + + @property + def image_pad_tag(self) -> str: + return self.tokenizer.image_pad_tag # type: ignore def __call__( self, @@ -787,7 +774,14 @@ def get_mm_max_tokens_per_item( return {"image": self.get_num_image_tokens()} def get_num_image_tokens(self) -> int: - return MAX_QWEN_IMG_TOKENS + hf_config = self.get_hf_config() + if not (vision_config := getattr(hf_config, "visual", None)): + return 0 + + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]): @@ -798,10 +792,12 @@ def get_dummy_processor_inputs( mm_counts: Mapping[str, int], ) -> ProcessorInputs: hf_config = self.info.get_hf_config() - if not hasattr(hf_config, "visual"): + if not (vision_config := getattr(hf_config, "visual", None)): return ProcessorInputs(prompt_text="", mm_data={}) - vision_config = hf_config.visual + processor = self.info.get_hf_processor() + img_start = processor.image_start_tag + img_end = processor.image_end_tag target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) @@ -814,7 +810,7 @@ def get_dummy_processor_inputs( } return ProcessorInputs( - prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n" + prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n" for i in range(1, num_images + 1)), mm_data=mm_data, ) @@ -869,13 +865,18 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + if not hasattr(hf_config, "visual"): + return [] + tokenizer = self.info.get_tokenizer() special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore - img_start_id = special_tokens[IMG_START] - img_end_id = special_tokens[IMG_END] - img_pad_id = special_tokens[IMG_PAD] + processor = self.info.get_hf_processor() + img_start_id = special_tokens[processor.image_start_tag] + img_end_id = special_tokens[processor.image_end_tag] + img_pad_id = special_tokens[processor.image_pad_tag] num_image_tokens = self.info.get_num_image_tokens() image_tokens = 
[img_pad_id] * num_image_tokens diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index e93cf46b900b6..d4c48dbdab13c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -40,7 +40,7 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed import parallel_state +from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -207,11 +207,12 @@ def __init__( ) -> None: super().__init__() # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size) + num_heads, self.tp_size) self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, @@ -231,6 +232,29 @@ def __init__( f"Qwen2.5-VL does not support {self.attn_backend} backend now." ) + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + def forward( self, x: torch.Tensor, @@ -240,15 +264,8 @@ def forward( # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) - # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - x = x.view(*new_x_shape) - - # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] - q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() @@ -665,24 +682,6 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight, shard_id) break else: - if name.endswith("qkv.weight"): - visual_num_heads = self.num_heads - visual_embed_dim = self.hidden_size - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size, - visual_embed_dim) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) - elif name.endswith("qkv.bias"): - visual_num_heads = self.num_heads - visual_embed_dim = self.hidden_size - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -760,9 +759,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "q_proj", "k_proj", "v_proj", - ] + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], } - # LoRA specific attributes, TODO: double check supported_lora_modules = [ "qkv_proj", diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 34ae7b8c94697..f2071eaff481f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -885,14 +885,10 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - num_frames = min(max(max_total_frames // max(max_videos, 1), 1), - _MAX_FRAMES_PER_VIDEO) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) - # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 - if num_frames > 1 and num_frames % 2 == 1: - num_frames += 1 - - return num_frames + return max(max_frames_per_video, 1) def get_max_video_tokens(self, seq_len: int) -> int: target_width, target_height = self.get_image_size_with_most_features() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3b2a7069efc91..c2d0fae7056c7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -37,6 +37,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), # ChatGLMModel supports multimodal "CohereForCausalLM": ("commandr", "CohereForCausalLM"), diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9da0682cfa866..063997a14a66f 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -258,27 +258,35 @@ def __init__(self, config: UltravoxConfig): super().__init__() self.hidden_dim = config.hidden_size self._pad_and_stack = StackAudioFrames(config.stack_factor) - dim = config.audio_config.hidden_size * config.stack_factor - self.ln_pre = RMSNorm(dim) - self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) - dim = self.hidden_dim + dim_in = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim_in) + self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False) + dim_mid = 
self.hidden_dim if config.projector_act == "swiglu": self.act = MulAndSilu() - dim = dim // 2 + dim_mid = dim_mid // 2 else: self.act = get_act_fn(config.projector_act) - self.linear_2 = nn.Linear(dim, - config.text_config.hidden_size, - bias=False) - self.ln_post = RMSNorm(config.text_config.hidden_size) + dim_out = config.text_config.hidden_size + self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) + + # Ultravox v0.4.1 and below use layer_norm after the second linear layer + # while v0.5.0 and above uses layer_norm after the first linear layer. + if config.projector_ln_mid: + self.ln_mid: nn.Module = RMSNorm(dim_mid) + self.ln_post = nn.Identity() + else: + self.ln_mid = nn.Identity() + self.ln_post = RMSNorm(dim_out) def forward(self, audio_features: torch.Tensor) -> torch.Tensor: audio_features = self._pad_and_stack(audio_features) audio_features = self.ln_pre(audio_features) hidden_states = self.linear_1(audio_features) hidden_states = self.act(hidden_states) + hidden_states = self.ln_mid(hidden_states) hidden_states = self.linear_2(hidden_states) hidden_states = self.ln_post(hidden_states) return hidden_states diff --git a/vllm/outputs.py b/vllm/outputs.py index 786380c37f6cb..030119710a187 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -142,6 +142,9 @@ def new( prompt_token_ids: Optional[List[int]], text: str, token_ids: List[int], + logprobs: Optional[SampleLogprobs], + prompt_logprobs: Optional[PromptLogprobs], + cumulative_logprob: Optional[float], finished: bool = False, ) -> "RequestOutput": """Initialize a new RequestOutput object.""" @@ -151,15 +154,14 @@ def new( index=0, text=text, token_ids=token_ids, - cumulative_logprob=None, - logprobs=None, # TODO - ) + cumulative_logprob=cumulative_logprob, + logprobs=logprobs) return RequestOutput( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, # TODO + prompt_logprobs=prompt_logprobs, outputs=[completion_output], finished=finished, ) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4e0683b8a2de1..179ee6a7d2478 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -35,7 +35,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool) -> str: - if selected_backend != _Backend.TORCH_SDPA: + if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Using Torch SDPA backend.") return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 991d55ac861a4..9deb0294668ec 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -334,10 +334,10 @@ def log_warnings(cls): if (len(set(device_names)) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"): logger.warning( - "Detected different devices in the system: \n%s\nPlease" + "Detected different devices in the system: %s. 
Please" " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " "avoid unexpected behavior.", - "\n".join(device_names), + ", ".join(device_names), ) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 211e288b125da..645d98a1bb42c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -211,16 +211,17 @@ def inference_mode(cls): return torch.inference_mode(mode=True) @classmethod - def seed_everything(cls, seed: int) -> None: + def seed_everything(cls, seed: Optional[int] = None) -> None: """ Set the seed of each random module. `torch.manual_seed` will set seed on all devices. Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) + if seed is not None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 035766289aebd..1f690b7111ee2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -72,7 +72,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8", "gguf", "quark" + "fbgemm_fp8", "gguf", "quark", "ptpc_fp8" ] @classmethod diff --git a/vllm/third_party/__init__.py b/vllm/third_party/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py new file mode 100644 index 0000000000000..0a4be23a09362 --- /dev/null +++ b/vllm/third_party/pynvml.py @@ -0,0 +1,6139 @@ +# SPDX-License-Identifier: Apache-2.0 +# copied from https://pypi.org/project/nvidia-ml-py +# version 12.570.86 + +##### +# Copyright (c) 2011-2023, NVIDIA Corporation. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA Corporation nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
+##### + +## +# Python bindings for the NVML library +## +from ctypes import * +from ctypes.util import find_library +from functools import wraps +import sys +import os +import threading +import string + +## C Type mappings ## +## Enums +_nvmlEnableState_t = c_uint +NVML_FEATURE_DISABLED = 0 +NVML_FEATURE_ENABLED = 1 + +_nvmlBrandType_t = c_uint +NVML_BRAND_UNKNOWN = 0 +NVML_BRAND_QUADRO = 1 +NVML_BRAND_TESLA = 2 +NVML_BRAND_NVS = 3 +NVML_BRAND_GRID = 4 # Deprecated from API reporting. Keeping definition for backward compatibility. +NVML_BRAND_GEFORCE = 5 +NVML_BRAND_TITAN = 6 +NVML_BRAND_NVIDIA_VAPPS = 7 # NVIDIA Virtual Applications +NVML_BRAND_NVIDIA_VPC = 8 # NVIDIA Virtual PC +NVML_BRAND_NVIDIA_VCS = 9 # NVIDIA Virtual Compute Server +NVML_BRAND_NVIDIA_VWS = 10 # NVIDIA RTX Virtual Workstation +NVML_BRAND_NVIDIA_CLOUD_GAMING = 11 # NVIDIA Cloud Gaming +NVML_BRAND_NVIDIA_VGAMING = NVML_BRAND_NVIDIA_CLOUD_GAMING # Deprecated from API reporting. Keeping definition for backward compatibility. +NVML_BRAND_QUADRO_RTX = 12 +NVML_BRAND_NVIDIA_RTX = 13 +NVML_BRAND_NVIDIA = 14 +NVML_BRAND_GEFORCE_RTX = 15 # Unused +NVML_BRAND_TITAN_RTX = 16 # Unused +NVML_BRAND_COUNT = 17 + +_nvmlTemperatureThresholds_t = c_uint +NVML_TEMPERATURE_THRESHOLD_SHUTDOWN = 0 +NVML_TEMPERATURE_THRESHOLD_SLOWDOWN = 1 +NVML_TEMPERATURE_THRESHOLD_MEM_MAX = 2 +NVML_TEMPERATURE_THRESHOLD_GPU_MAX = 3 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MIN = 4 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_CURR = 5 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MAX = 6 +NVML_TEMPERATURE_THRESHOLD_GPS_CURR = 7 +NVML_TEMPERATURE_THRESHOLD_COUNT = 8 + +_nvmlTemperatureSensors_t = c_uint +NVML_TEMPERATURE_GPU = 0 +NVML_TEMPERATURE_COUNT = 1 + + +_nvmlComputeMode_t = c_uint +NVML_COMPUTEMODE_DEFAULT = 0 +NVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1 ## Support Removed +NVML_COMPUTEMODE_PROHIBITED = 2 +NVML_COMPUTEMODE_EXCLUSIVE_PROCESS = 3 +NVML_COMPUTEMODE_COUNT = 4 + +_nvmlMemoryLocation_t = c_uint +NVML_MEMORY_LOCATION_L1_CACHE = 0 +NVML_MEMORY_LOCATION_L2_CACHE = 1 +NVML_MEMORY_LOCATION_DEVICE_MEMORY = 2 +NVML_MEMORY_LOCATION_DRAM = 2 +NVML_MEMORY_LOCATION_REGISTER_FILE = 3 +NVML_MEMORY_LOCATION_TEXTURE_MEMORY = 4 +NVML_MEMORY_LOCATION_TEXTURE_SHM = 5 +NVML_MEMORY_LOCATION_CBU = 6 +NVML_MEMORY_LOCATION_SRAM = 7 +NVML_MEMORY_LOCATION_COUNT = 8 + +NVML_NVLINK_MAX_LINKS = 18 + +# For backwards compatibility, maintain the incorrectly-named "LANES" define +NVML_NVLINK_MAX_LANES = NVML_NVLINK_MAX_LINKS + +_nvmlNvLinkErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_REPLAY = 0 +NVML_NVLINK_ERROR_DL_RECOVERY = 1 +NVML_NVLINK_ERROR_DL_CRC_FLIT = 2 +NVML_NVLINK_ERROR_DL_CRC_DATA = 3 +NVML_NVLINK_ERROR_DL_ECC_DATA = 4 +NVML_NVLINK_ERROR_COUNT = 5 + +_nvmlNvLinkEccLaneErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_ECC_LANE0 = 0 +NVML_NVLINK_ERROR_DL_ECC_LANE1 = 1 +NVML_NVLINK_ERROR_DL_ECC_LANE2 = 2 +NVML_NVLINK_ERROR_DL_ECC_LANE3 = 3 +NVML_NVLINK_ERROR_DL_ECC_COUNT = 5 + +_nvmlNvLinkCapability_t = c_uint +NVML_NVLINK_CAP_P2P_SUPPORTED = 0 +NVML_NVLINK_CAP_SYSMEM_ACCESS = 1 +NVML_NVLINK_CAP_P2P_ATOMICS = 2 +NVML_NVLINK_CAP_SYSMEM_ATOMICS= 3 +NVML_NVLINK_CAP_SLI_BRIDGE = 4 +NVML_NVLINK_CAP_VALID = 5 +NVML_NVLINK_CAP_COUNT = 6 + +_nvmlNvLinkUtilizationCountPktTypes_t = c_uint +NVML_NVLINK_COUNTER_PKTFILTER_NOP = 0x1 +NVML_NVLINK_COUNTER_PKTFILTER_READ = 0x2 +NVML_NVLINK_COUNTER_PKTFILTER_WRITE = 0x4 +NVML_NVLINK_COUNTER_PKTFILTER_RATOM = 0x8 +NVML_NVLINK_COUNTER_PKTFILTER_NRATOM = 0x10 +NVML_NVLINK_COUNTER_PKTFILTER_FLUSH = 0x20 +NVML_NVLINK_COUNTER_PKTFILTER_RESPDATA = 0x40 
+NVML_NVLINK_COUNTER_PKTFILTER_RESPNODATA = 0x80 +NVML_NVLINK_COUNTER_PKTFILTER_ALL = 0xFF + +_nvmlNvLinkUtilizationCountUnits_t = c_uint +NVML_NVLINK_COUNTER_UNIT_CYCLES = 0 +NVML_NVLINK_COUNTER_UNIT_PACKETS = 1 +NVML_NVLINK_COUNTER_UNIT_BYTES = 2 +NVML_NVLINK_COUNTER_UNIT_RESERVED = 3 +NVML_NVLINK_COUNTER_UNIT_COUNT = 4 + +_nvmlNvLinkDeviceType_t = c_uint +NVML_NVLINK_DEVICE_TYPE_GPU = 0x00 +NVML_NVLINK_DEVICE_TYPE_IBMNPU = 0x01 +NVML_NVLINK_DEVICE_TYPE_SWITCH = 0x02 +NVML_NVLINK_DEVICE_TYPE_UNKNOWN = 0xFF + +# These are deprecated, instead use _nvmlMemoryErrorType_t +_nvmlEccBitType_t = c_uint +NVML_SINGLE_BIT_ECC = 0 +NVML_DOUBLE_BIT_ECC = 1 +NVML_ECC_ERROR_TYPE_COUNT = 2 + +_nvmlEccCounterType_t = c_uint +NVML_VOLATILE_ECC = 0 +NVML_AGGREGATE_ECC = 1 +NVML_ECC_COUNTER_TYPE_COUNT = 2 + +_nvmlMemoryErrorType_t = c_uint +NVML_MEMORY_ERROR_TYPE_CORRECTED = 0 +NVML_MEMORY_ERROR_TYPE_UNCORRECTED = 1 +NVML_MEMORY_ERROR_TYPE_COUNT = 2 + +_nvmlClockType_t = c_uint +NVML_CLOCK_GRAPHICS = 0 +NVML_CLOCK_SM = 1 +NVML_CLOCK_MEM = 2 +NVML_CLOCK_VIDEO = 3 +NVML_CLOCK_COUNT = 4 + +_nvmlClockId_t = c_uint +NVML_CLOCK_ID_CURRENT = 0 +NVML_CLOCK_ID_APP_CLOCK_TARGET = 1 +NVML_CLOCK_ID_APP_CLOCK_DEFAULT = 2 +NVML_CLOCK_ID_CUSTOMER_BOOST_MAX = 3 +NVML_CLOCK_ID_COUNT = 4 + +_nvmlDriverModel_t = c_uint +NVML_DRIVER_WDDM = 0 +NVML_DRIVER_WDM = 1 +NVML_DRIVER_MCDM = 2 + +NVML_MAX_GPU_PERF_PSTATES = 16 + +_nvmlPstates_t = c_uint +NVML_PSTATE_0 = 0 +NVML_PSTATE_1 = 1 +NVML_PSTATE_2 = 2 +NVML_PSTATE_3 = 3 +NVML_PSTATE_4 = 4 +NVML_PSTATE_5 = 5 +NVML_PSTATE_6 = 6 +NVML_PSTATE_7 = 7 +NVML_PSTATE_8 = 8 +NVML_PSTATE_9 = 9 +NVML_PSTATE_10 = 10 +NVML_PSTATE_11 = 11 +NVML_PSTATE_12 = 12 +NVML_PSTATE_13 = 13 +NVML_PSTATE_14 = 14 +NVML_PSTATE_15 = 15 +NVML_PSTATE_UNKNOWN = 32 + +_nvmlInforomObject_t = c_uint +NVML_INFOROM_OEM = 0 +NVML_INFOROM_ECC = 1 +NVML_INFOROM_POWER = 2 +NVML_INFOROM_DEN = 3 +NVML_INFOROM_COUNT = 4 + +_nvmlReturn_t = c_uint +NVML_SUCCESS = 0 +NVML_ERROR_UNINITIALIZED = 1 +NVML_ERROR_INVALID_ARGUMENT = 2 +NVML_ERROR_NOT_SUPPORTED = 3 +NVML_ERROR_NO_PERMISSION = 4 +NVML_ERROR_ALREADY_INITIALIZED = 5 +NVML_ERROR_NOT_FOUND = 6 +NVML_ERROR_INSUFFICIENT_SIZE = 7 +NVML_ERROR_INSUFFICIENT_POWER = 8 +NVML_ERROR_DRIVER_NOT_LOADED = 9 +NVML_ERROR_TIMEOUT = 10 +NVML_ERROR_IRQ_ISSUE = 11 +NVML_ERROR_LIBRARY_NOT_FOUND = 12 +NVML_ERROR_FUNCTION_NOT_FOUND = 13 +NVML_ERROR_CORRUPTED_INFOROM = 14 +NVML_ERROR_GPU_IS_LOST = 15 +NVML_ERROR_RESET_REQUIRED = 16 +NVML_ERROR_OPERATING_SYSTEM = 17 +NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18 +NVML_ERROR_IN_USE = 19 +NVML_ERROR_MEMORY = 20 +NVML_ERROR_NO_DATA = 21 +NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22 +NVML_ERROR_INSUFFICIENT_RESOURCES = 23 +NVML_ERROR_FREQ_NOT_SUPPORTED = 24 +NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25 +NVML_ERROR_DEPRECATED = 26 +NVML_ERROR_NOT_READY = 27 +NVML_ERROR_GPU_NOT_FOUND = 28 +NVML_ERROR_INVALID_STATE = 29 +NVML_ERROR_UNKNOWN = 999 + +_nvmlFanState_t = c_uint +NVML_FAN_NORMAL = 0 +NVML_FAN_FAILED = 1 + +_nvmlFanControlPolicy_t = c_uint +NVML_FAN_POLICY_TEMPERATURE_CONTINOUS_SW = 0 +NVML_FAN_POLICY_MANUAL = 1 + +_nvmlLedColor_t = c_uint +NVML_LED_COLOR_GREEN = 0 +NVML_LED_COLOR_AMBER = 1 + +_nvmlGpuOperationMode_t = c_uint +NVML_GOM_ALL_ON = 0 +NVML_GOM_COMPUTE = 1 +NVML_GOM_LOW_DP = 2 + +_nvmlPageRetirementCause_t = c_uint +NVML_PAGE_RETIREMENT_CAUSE_MULTIPLE_SINGLE_BIT_ECC_ERRORS = 0 +NVML_PAGE_RETIREMENT_CAUSE_DOUBLE_BIT_ECC_ERROR = 1 +NVML_PAGE_RETIREMENT_CAUSE_COUNT = 2 + +_nvmlRestrictedAPI_t = c_uint +NVML_RESTRICTED_API_SET_APPLICATION_CLOCKS = 0 
+NVML_RESTRICTED_API_SET_AUTO_BOOSTED_CLOCKS = 1 +NVML_RESTRICTED_API_COUNT = 2 + +_nvmlBridgeChipType_t = c_uint +NVML_BRIDGE_CHIP_PLX = 0 +NVML_BRIDGE_CHIP_BRO4 = 1 +NVML_MAX_PHYSICAL_BRIDGE = 128 + +_nvmlValueType_t = c_uint +NVML_VALUE_TYPE_DOUBLE = 0 +NVML_VALUE_TYPE_UNSIGNED_INT = 1 +NVML_VALUE_TYPE_UNSIGNED_LONG = 2 +NVML_VALUE_TYPE_UNSIGNED_LONG_LONG = 3 +NVML_VALUE_TYPE_SIGNED_LONG_LONG = 4 +NVML_VALUE_TYPE_SIGNED_INT = 5 +NVML_VALUE_TYPE_UNSIGNED_SHORT = 6 +NVML_VALUE_TYPE_COUNT = 7 + +_nvmlNvlinkVersion_t = c_uint +NVML_NVLINK_VERSION_INVALID = 0 +NVML_NVLINK_VERSION_1_0 = 1 +NVML_NVLINK_VERSION_2_0 = 2 +NVML_NVLINK_VERSION_2_2 = 3 +NVML_NVLINK_VERSION_3_0 = 4 +NVML_NVLINK_VERSION_3_1 = 5 +NVML_NVLINK_VERSION_4_0 = 6 +NVML_NVLINK_VERSION_5_0 = 7 + +_nvmlPerfPolicyType_t = c_uint +NVML_PERF_POLICY_POWER = 0 +NVML_PERF_POLICY_THERMAL = 1 +NVML_PERF_POLICY_SYNC_BOOST = 2 +NVML_PERF_POLICY_BOARD_LIMIT = 3 +NVML_PERF_POLICY_LOW_UTILIZATION = 4 +NVML_PERF_POLICY_RELIABILITY = 5 +NVML_PERF_POLICY_TOTAL_APP_CLOCKS = 10 +NVML_PERF_POLICY_TOTAL_BASE_CLOCKS = 11 +NVML_PERF_POLICY_COUNT = 12 + +_nvmlEncoderQueryType_t = c_uint +NVML_ENCODER_QUERY_H264 = 0 +NVML_ENCODER_QUERY_HEVC = 1 +NVML_ENCODER_QUERY_AV1 = 2 +NVML_ENCODER_QUERY_UNKNOWN = 255 + +_nvmlFBCSessionType_t = c_uint +NVML_FBC_SESSION_TYPE_UNKNOWN = 0 +NVML_FBC_SESSION_TYPE_TOSYS = 1 +NVML_FBC_SESSION_TYPE_CUDA = 2 +NVML_FBC_SESSION_TYPE_VID = 3 +NVML_FBC_SESSION_TYPE_HWENC = 4 + +_nvmlDetachGpuState_t = c_uint +NVML_DETACH_GPU_KEEP = 0 +NVML_DETACH_GPU_REMOVE = 1 + +_nvmlPcieLinkState_t = c_uint +NVML_PCIE_LINK_KEEP = 0 +NVML_PCIE_LINK_SHUT_DOWN = 1 + +_nvmlSamplingType_t = c_uint +NVML_TOTAL_POWER_SAMPLES = 0 +NVML_GPU_UTILIZATION_SAMPLES = 1 +NVML_MEMORY_UTILIZATION_SAMPLES = 2 +NVML_ENC_UTILIZATION_SAMPLES = 3 +NVML_DEC_UTILIZATION_SAMPLES = 4 +NVML_PROCESSOR_CLK_SAMPLES = 5 +NVML_MEMORY_CLK_SAMPLES = 6 +NVML_MODULE_POWER_SAMPLES = 7 +NVML_JPG_UTILIZATION_SAMPLES = 8 +NVML_OFA_UTILIZATION_SAMPLES = 9 +NVML_SAMPLINGTYPE_COUNT = 10 + +_nvmlPcieUtilCounter_t = c_uint +NVML_PCIE_UTIL_TX_BYTES = 0 +NVML_PCIE_UTIL_RX_BYTES = 1 +NVML_PCIE_UTIL_COUNT = 2 + +_nvmlGpuTopologyLevel_t = c_uint +NVML_TOPOLOGY_INTERNAL = 0 +NVML_TOPOLOGY_SINGLE = 10 +NVML_TOPOLOGY_MULTIPLE = 20 +NVML_TOPOLOGY_HOSTBRIDGE = 30 +NVML_TOPOLOGY_NODE = 40 +NVML_TOPOLOGY_CPU = NVML_TOPOLOGY_NODE +NVML_TOPOLOGY_SYSTEM = 50 + +_nvmlGpuP2PCapsIndex_t = c_uint +NVML_P2P_CAPS_INDEX_READ = 0, +NVML_P2P_CAPS_INDEX_WRITE = 1 +NVML_P2P_CAPS_INDEX_NVLINK =2 +NVML_P2P_CAPS_INDEX_ATOMICS = 3 +# +# NVML_P2P_CAPS_INDEX_PROP is deprecated. +# Use NVML_P2P_CAPS_INDEX_PCI instead. 
+# +NVML_P2P_CAPS_INDEX_PROP = 4 +NVML_P2P_CAPS_INDEX_PCI = 4 +NVML_P2P_CAPS_INDEX_UNKNOWN = 5 + +_nvmlGpuP2PStatus_t = c_uint +NVML_P2P_STATUS_OK = 0 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED = 1 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED = NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED +NVML_P2P_STATUS_GPU_NOT_SUPPORTED = 2 +NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED =3 +NVML_P2P_STATUS_DISABLED_BY_REGKEY =4 +NVML_P2P_STATUS_NOT_SUPPORTED =5 +NVML_P2P_STATUS_UNKNOWN =6 + +_nvmlDeviceArchitecture_t = c_uint +NVML_DEVICE_ARCH_KEPLER = 2 +NVML_DEVICE_ARCH_MAXWELL = 3 +NVML_DEVICE_ARCH_PASCAL = 4 +NVML_DEVICE_ARCH_VOLTA = 5 +NVML_DEVICE_ARCH_TURING = 6 +NVML_DEVICE_ARCH_AMPERE = 7 +NVML_DEVICE_ARCH_ADA = 8 +NVML_DEVICE_ARCH_HOPPER = 9 +NVML_DEVICE_ARCH_BLACKWELL = 10 +NVML_DEVICE_ARCH_T23X = 11 +NVML_DEVICE_ARCH_UNKNOWN = 0xffffffff + +# PCI bus Types +_nvmlBusType_t = c_uint +NVML_BUS_TYPE_UNKNOWN = 0 +NVML_BUS_TYPE_PCI = 1 +NVML_BUS_TYPE_PCIE = 2 +NVML_BUS_TYPE_FPCI = 3 +NVML_BUS_TYPE_AGP = 4 + +_nvmlPowerSource_t = c_uint +NVML_POWER_SOURCE_AC = 0x00000000 +NVML_POWER_SOURCE_BATTERY = 0x00000001 +NVML_POWER_SOURCE_UNDERSIZED = 0x00000002 + +_nvmlAdaptiveClockInfoStatus_t = c_uint +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_DISABLED = 0x00000000 +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_ENABLED = 0x00000001 + +_nvmlClockLimitId_t = c_uint +NVML_CLOCK_LIMIT_ID_RANGE_START = 0xffffff00 +NVML_CLOCK_LIMIT_ID_TDP = 0xffffff01 +NVML_CLOCK_LIMIT_ID_UNLIMITED = 0xffffff02 + +_nvmlPcieLinkMaxSpeed_t = c_uint +NVML_PCIE_LINK_MAX_SPEED_INVALID = 0x00000000 +NVML_PCIE_LINK_MAX_SPEED_2500MBPS = 0x00000001 +NVML_PCIE_LINK_MAX_SPEED_5000MBPS = 0x00000002 +NVML_PCIE_LINK_MAX_SPEED_8000MBPS = 0x00000003 +NVML_PCIE_LINK_MAX_SPEED_16000MBPS = 0x00000004 +NVML_PCIE_LINK_MAX_SPEED_32000MBPS = 0x00000005 +NVML_PCIE_LINK_MAX_SPEED_64000MBPS = 0x00000006 + +_nvmlPcieAtomicsCapability_t = c_uint +NVML_PCIE_ATOMICS_CAP_FETCHADD32 = 0x01 +NVML_PCIE_ATOMICS_CAP_FETCHADD64 = 0x02 +NVML_PCIE_ATOMICS_CAP_SWAP32 = 0x04 +NVML_PCIE_ATOMICS_CAP_SWAP64 = 0x08 +NVML_PCIE_ATOMICS_CAP_CAS32 = 0x10 +NVML_PCIE_ATOMICS_CAP_CAS64 = 0x20 +NVML_PCIE_ATOMICS_CAP_CAS128 = 0x40 +NVML_PCIE_ATOMICS_OPS_MAX = 7 + +_nvmlAffinityScope_t = c_uint +NVML_AFFINITY_SCOPE_NODE = 0 +NVML_AFFINITY_SCOPE_SOCKET = 1 + +_nvmlDeviceGpuRecoveryAction_t = c_uint +NVML_GPU_RECOVERY_ACTION_NONE = 0 +NVML_GPU_RECOVERY_ACTION_GPU_RESET = 1 +NVML_GPU_RECOVERY_ACTION_NODE_REBOOT = 2 +NVML_GPU_RECOVERY_ACTION_DRAIN_P2P = 3 +NVML_GPU_RECOVERY_ACTION_DRAIN_AND_RESET = 4 + +# C preprocessor defined values +nvmlFlagDefault = 0 +nvmlFlagForce = 1 +NVML_INIT_FLAG_NO_GPUS = 1 +NVML_INIT_FLAG_NO_ATTACH = 2 + +NVML_MAX_GPC_COUNT = 32 + +# buffer size +NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE = 16 +NVML_DEVICE_UUID_BUFFER_SIZE = 80 +NVML_DEVICE_UUID_V2_BUFFER_SIZE = 96 +NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE = 80 +NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE = 80 +NVML_DEVICE_NAME_BUFFER_SIZE = 64 +NVML_DEVICE_NAME_V2_BUFFER_SIZE = 96 +NVML_DEVICE_SERIAL_BUFFER_SIZE = 30 +NVML_DEVICE_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_GPU_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE = 16 +NVML_GRID_LICENSE_BUFFER_SIZE = 128 +NVML_VGPU_NAME_BUFFER_SIZE = 64 +NVML_GRID_LICENSE_FEATURE_MAX_COUNT = 3 +NVML_VGPU_METADATA_OPAQUE_DATA_SIZE = sizeof(c_uint) + 256 +NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE = 256 +NVML_DEVICE_GPU_FRU_PART_NUMBER_BUFFER_SIZE = 0x14 # NV2080_GPU_MAX_PRODUCT_PART_NUMBER_LENGTH 
+NVML_PERF_MODES_BUFFER_SIZE = 2048 + +# Format strings +NVML_DEVICE_PCI_BUS_ID_LEGACY_FMT = "%04X:%02X:%02X.0" +NVML_DEVICE_PCI_BUS_ID_FMT = "%08X:%02X:%02X.0" + +NVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1) +NVML_VALUE_NOT_AVAILABLE_uint = c_uint(-1) + +''' + Field Identifiers. + + All Identifiers pertain to a device. Each ID is only used once and is guaranteed never to change. +''' +NVML_FI_DEV_ECC_CURRENT = 1 # Current ECC mode. 1=Active. 0=Inactive +NVML_FI_DEV_ECC_PENDING = 2 # Pending ECC mode. 1=Active. 0=Inactive + +#ECC Count Totals +NVML_FI_DEV_ECC_SBE_VOL_TOTAL = 3 # Total single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TOTAL = 4 # Total double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TOTAL = 5 # Total single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TOTAL = 6 # Total double bit aggregate (persistent) ECC errors +#Individual ECC locations +NVML_FI_DEV_ECC_SBE_VOL_L1 = 7 # L1 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L1 = 8 # L1 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_L2 = 9 # L2 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L2 = 10 # L2 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_DEV = 11 # Device memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_DEV = 12 # Device memory double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_REG = 13 # Register file single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_REG = 14 # Register file double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_TEX = 15 # Texture memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TEX = 16 # Texture memory double bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_CBU = 17 # CBU double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L1 = 18 # L1 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L1 = 19 # L1 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L2 = 20 # L2 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L2 = 21 # L2 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_DEV = 22 # Device memory single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_DEV = 23 # Device memory double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_REG = 24 # Register File single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_REG = 25 # Register File double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TEX = 26 # Texture memory single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TEX = 27 # Texture memory double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_CBU = 28 # CBU double bit aggregate ECC errors + +# Page Retirement +NVML_FI_DEV_RETIRED_SBE = 29 # Number of retired pages because of single bit errors +NVML_FI_DEV_RETIRED_DBE = 30 # Number of retired pages because of double bit errors +NVML_FI_DEV_RETIRED_PENDING = 31 # If any pages are pending retirement. 1=yes. 0=no. 
+ +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L0 = 32 # NVLink flow control CRC Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L1 = 33 # NVLink flow control CRC Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L2 = 34 # NVLink flow control CRC Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L3 = 35 # NVLink flow control CRC Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L4 = 36 # NVLink flow control CRC Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L5 = 37 # NVLink flow control CRC Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_TOTAL = 38 # NVLink flow control CRC Error Counter total for all Lanes + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L0 = 39 # NVLink data CRC Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L1 = 40 # NVLink data CRC Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L2 = 41 # NVLink data CRC Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L3 = 42 # NVLink data CRC Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L4 = 43 # NVLink data CRC Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L5 = 44 # NVLink data CRC Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_TOTAL = 45 # NvLink data CRC Error Counter total for all Lanes + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L0 = 46 # NVLink Replay Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L1 = 47 # NVLink Replay Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L2 = 48 # NVLink Replay Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L3 = 49 # NVLink Replay Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L4 = 50 # NVLink Replay Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L5 = 51 # NVLink Replay Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_TOTAL = 52 # NVLink Replay Error Counter total for all Lanes + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L0 = 53 # NVLink Recovery Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L1 = 54 # NVLink Recovery Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L2 = 55 # NVLink Recovery Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L3 = 56 # NVLink Recovery Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L4 = 57 # NVLink Recovery Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L5 = 58 # NVLink Recovery Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_TOTAL = 59 # NVLink Recovery Error Counter total for all Lanes + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L0 = 60 # NVLink Bandwidth Counter for Counter Set 0, Lane 0 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L1 = 61 # NVLink Bandwidth Counter for Counter Set 0, Lane 1 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L2 = 62 # NVLink Bandwidth Counter for Counter Set 0, Lane 2 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L3 = 63 # NVLink Bandwidth Counter for Counter Set 0, Lane 3 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L4 = 64 # NVLink Bandwidth Counter for Counter Set 0, Lane 4 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L5 = 65 # NVLink Bandwidth Counter for Counter Set 0, Lane 5 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_TOTAL = 66 # NVLink Bandwidth Counter Total for Counter Set 0, All Lanes + +# NvLink Bandwidth Counters 
+NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L0 = 67 # NVLink Bandwidth Counter for Counter Set 1, Lane 0 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L1 = 68 # NVLink Bandwidth Counter for Counter Set 1, Lane 1 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L2 = 69 # NVLink Bandwidth Counter for Counter Set 1, Lane 2 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L3 = 70 # NVLink Bandwidth Counter for Counter Set 1, Lane 3 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L4 = 71 # NVLink Bandwidth Counter for Counter Set 1, Lane 4 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L5 = 72 # NVLink Bandwidth Counter for Counter Set 1, Lane 5 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_TOTAL = 73 # NVLink Bandwidth Counter Total for Counter Set 1, All Lanes + +# Perf Policy Counters +NVML_FI_DEV_PERF_POLICY_POWER = 74 # Perf Policy Counter for Power Policy +NVML_FI_DEV_PERF_POLICY_THERMAL = 75 # Perf Policy Counter for Thermal Policy +NVML_FI_DEV_PERF_POLICY_SYNC_BOOST = 76 # Perf Policy Counter for Sync boost Policy +NVML_FI_DEV_PERF_POLICY_BOARD_LIMIT = 77 # Perf Policy Counter for Board Limit +NVML_FI_DEV_PERF_POLICY_LOW_UTILIZATION = 78 # Perf Policy Counter for Low GPU Utilization Policy +NVML_FI_DEV_PERF_POLICY_RELIABILITY = 79 # Perf Policy Counter for Reliability Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_APP_CLOCKS = 80 # Perf Policy Counter for Total App Clock Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_BASE_CLOCKS = 81 # Perf Policy Counter for Total Base Clocks Policy + +# Memory temperatures +NVML_FI_DEV_MEMORY_TEMP = 82 # Memory temperature for the device + +# Energy Counter +NVML_FI_DEV_TOTAL_ENERGY_CONSUMPTION = 83 # Total energy consumption for the GPU in mJ since the driver was last reloaded + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L0 = 84 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L1 = 85 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L2 = 86 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L3 = 87 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L4 = 88 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L5 = 89 +NVML_FI_DEV_NVLINK_SPEED_MBPS_COMMON = 90 + +# NVLink Link Count +NVML_FI_DEV_NVLINK_LINK_COUNT = 91 + +# Page Retirement pending fields +NVML_FI_DEV_RETIRED_PENDING_SBE = 92 +NVML_FI_DEV_RETIRED_PENDING_DBE = 93 + +# PCIe replay and replay rollover counters +NVML_FI_DEV_PCIE_REPLAY_COUNTER = 94 +NVML_FI_DEV_PCIE_REPLAY_ROLLOVER_COUNTER = 95 + +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L6 = 96 # NVLink flow control CRC Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L7 = 97 # NVLink flow control CRC Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L8 = 98 # NVLink flow control CRC Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L9 = 99 # NVLink flow control CRC Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L10 = 100 # NVLink flow control CRC Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L11 = 101 # NVLink flow control CRC Error Counter for Lane 11 + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L6 = 102 # NVLink data CRC Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L7 = 103 # NVLink data CRC Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L8 = 104 # NVLink data CRC Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L9 = 105 # NVLink data CRC Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L10 = 106 # NVLink data CRC Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L11 = 107 # NVLink data CRC Error Counter for Lane 11 + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L6 = 108 # NVLink 
Replay Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L7 = 109 # NVLink Replay Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L8 = 110 # NVLink Replay Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L9 = 111 # NVLink Replay Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L10 = 112 # NVLink Replay Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L11 = 113 # NVLink Replay Error Counter for Lane 11 + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L6 = 114 # NVLink Recovery Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L7 = 115 # NVLink Recovery Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L8 = 116 # NVLink Recovery Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L9 = 117 # NVLink Recovery Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L10 = 118 # NVLink Recovery Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L11 = 119 # NVLink Recovery Error Counter for Lane 11 + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L6 = 120 # NVLink Bandwidth Counter for Counter Set 0, Lane 6 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L7 = 121 # NVLink Bandwidth Counter for Counter Set 0, Lane 7 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L8 = 122 # NVLink Bandwidth Counter for Counter Set 0, Lane 8 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L9 = 123 # NVLink Bandwidth Counter for Counter Set 0, Lane 9 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L10 = 124 # NVLink Bandwidth Counter for Counter Set 0, Lane 10 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L11 = 125 # NVLink Bandwidth Counter for Counter Set 0, Lane 11 + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L6 = 126 # NVLink Bandwidth Counter for Counter Set 1, Lane 6 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L7 = 127 # NVLink Bandwidth Counter for Counter Set 1, Lane 7 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L8 = 128 # NVLink Bandwidth Counter for Counter Set 1, Lane 8 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L9 = 129 # NVLink Bandwidth Counter for Counter Set 1, Lane 9 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L10 = 130 # NVLink Bandwidth Counter for Counter Set 1, Lane 10 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L11 = 131 # NVLink Bandwidth Counter for Counter Set 1, Lane 11 + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L6 = 132 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L7 = 133 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L8 = 134 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L9 = 135 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L10 = 136 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L11 = 137 + +# NVLink Throughput Counters +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_TX = 138 # NVLink TX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_RX = 139 # NVLink RX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_TX = 140 # NVLink TX Data + protocol overhead in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_RX = 141 # NVLink RX Data + protocol overhead in KiB + +# Row Remapper +NVML_FI_DEV_REMAPPED_COR = 142 +NVML_FI_DEV_REMAPPED_UNC = 143 +NVML_FI_DEV_REMAPPED_PENDING = 144 +NVML_FI_DEV_REMAPPED_FAILURE = 145 + +#Remote device NVLink ID +NVML_FI_DEV_NVLINK_REMOTE_NVLINK_ID = 146 + +# Number of NVLinks connected to NVSwitch +NVML_FI_DEV_NVSWITCH_CONNECTED_LINK_COUNT = 147 + +# NvLink ECC Data Error Counters +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L0 = 148 #< NVLink data ECC Error Counter for Link 0 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L1 = 149 #< NVLink data ECC Error Counter for Link 1 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L2 = 150 #< NVLink data ECC Error Counter for Link 2 
+NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L3 = 151 #< NVLink data ECC Error Counter for Link 3 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L4 = 152 #< NVLink data ECC Error Counter for Link 4 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L5 = 153 #< NVLink data ECC Error Counter for Link 5 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L6 = 154 #< NVLink data ECC Error Counter for Link 6 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L7 = 155 #< NVLink data ECC Error Counter for Link 7 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L8 = 156 #< NVLink data ECC Error Counter for Link 8 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L9 = 157 #< NVLink data ECC Error Counter for Link 9 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L10 = 158 #< NVLink data ECC Error Counter for Link 10 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L11 = 159 #< NVLink data ECC Error Counter for Link 11 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_TOTAL = 160 #< NvLink data ECC Error Counter total for all Links + +NVML_FI_DEV_NVLINK_ERROR_DL_REPLAY = 161 +NVML_FI_DEV_NVLINK_ERROR_DL_RECOVERY = 162 +NVML_FI_DEV_NVLINK_ERROR_DL_CRC = 163 +NVML_FI_DEV_NVLINK_GET_SPEED = 164 +NVML_FI_DEV_NVLINK_GET_STATE = 165 +NVML_FI_DEV_NVLINK_GET_VERSION = 166 + +NVML_FI_DEV_NVLINK_GET_POWER_STATE = 167 +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD = 168 + +NVML_FI_DEV_PCIE_L0_TO_RECOVERY_COUNTER = 169 + +NVML_FI_DEV_C2C_LINK_COUNT = 170 +NVML_FI_DEV_C2C_LINK_GET_STATUS = 171 +NVML_FI_DEV_C2C_LINK_GET_MAX_BW = 172 + +NVML_FI_DEV_PCIE_COUNT_CORRECTABLE_ERRORS = 173 +NVML_FI_DEV_PCIE_COUNT_NAKS_RECEIVED = 174 +NVML_FI_DEV_PCIE_COUNT_RECEIVER_ERROR = 175 +NVML_FI_DEV_PCIE_COUNT_BAD_TLP = 176 +NVML_FI_DEV_PCIE_COUNT_NAKS_SENT = 177 +NVML_FI_DEV_PCIE_COUNT_BAD_DLLP = 178 +NVML_FI_DEV_PCIE_COUNT_NON_FATAL_ERROR = 179 +NVML_FI_DEV_PCIE_COUNT_FATAL_ERROR = 180 +NVML_FI_DEV_PCIE_COUNT_UNSUPPORTED_REQ = 181 +NVML_FI_DEV_PCIE_COUNT_LCRC_ERROR = 182 +NVML_FI_DEV_PCIE_COUNT_LANE_ERROR = 183 + +NVML_FI_DEV_IS_RESETLESS_MIG_SUPPORTED = 184 + +NVML_FI_DEV_POWER_AVERAGE = 185 +NVML_FI_DEV_POWER_INSTANT = 186 +NVML_FI_DEV_POWER_MIN_LIMIT = 187 +NVML_FI_DEV_POWER_MAX_LIMIT = 188 +NVML_FI_DEV_POWER_DEFAULT_LIMIT = 189 +NVML_FI_DEV_POWER_CURRENT_LIMIT = 190 +NVML_FI_DEV_ENERGY = 191 +NVML_FI_DEV_POWER_REQUESTED_LIMIT = 192 + +NVML_FI_DEV_TEMPERATURE_SHUTDOWN_TLIMIT = 193 +NVML_FI_DEV_TEMPERATURE_SLOWDOWN_TLIMIT = 194 +NVML_FI_DEV_TEMPERATURE_MEM_MAX_TLIMIT = 195 +NVML_FI_DEV_TEMPERATURE_GPU_MAX_TLIMIT = 196 + +NVML_FI_DEV_PCIE_COUNT_TX_BYTES = 197 +NVML_FI_DEV_PCIE_COUNT_RX_BYTES = 198 + +NVML_FI_DEV_IS_MIG_MODE_INDEPENDENT_MIG_QUERY_CAPABLE = 199 + +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MAX = 200 + +NVML_FI_DEV_NVLINK_COUNT_XMIT_PACKETS = 201 +NVML_FI_DEV_NVLINK_COUNT_XMIT_BYTES = 202 +NVML_FI_DEV_NVLINK_COUNT_RCV_PACKETS = 203 +NVML_FI_DEV_NVLINK_COUNT_RCV_BYTES = 204 +NVML_FI_DEV_NVLINK_COUNT_VL15_DROPPED = 205 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_MALFORMED_PACKET_ERRORS = 206 +NVML_FI_DEV_NVLINK_COUNT_BUFFER_OVERRUN_ERRORS = 207 +NVML_FI_DEV_NVLINK_COUNT_RCV_ERRORS = 208 +NVML_FI_DEV_NVLINK_COUNT_RCV_REMOTE_ERRORS = 209 +NVML_FI_DEV_NVLINK_COUNT_RCV_GENERAL_ERRORS = 210 +NVML_FI_DEV_NVLINK_COUNT_LOCAL_LINK_INTEGRITY_ERRORS = 211 +NVML_FI_DEV_NVLINK_COUNT_XMIT_DISCARDS = 212 + +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_SUCCESSFUL_EVENTS = 213 +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_FAILED_EVENTS = 214 +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_EVENTS = 215 + +NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE0 = 216 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE1 = 217 # Deprecated, do 
not use +NVML_FI_DEV_NVLINK_COUNT_RAW_BER = 218 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_ERRORS = 219 +NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_BER = 220 +NVML_FI_DEV_NVLINK_COUNT_SYMBOL_ERRORS = 221 +NVML_FI_DEV_NVLINK_COUNT_SYMBOL_BER = 222 + +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MIN = 223 +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS = 224 # Values are in the form NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_* +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_SUPPORTED = 225 + +NVML_FI_DEV_RESET_STATUS = 226 # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead +NVML_FI_DEV_DRAIN_AND_RESET_STATUS = 227 # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead +NVML_FI_DEV_PCIE_OUTBOUND_ATOMICS_MASK = 228 +NVML_FI_DEV_PCIE_INBOUND_ATOMICS_MASK = 229 +NVML_FI_DEV_GET_GPU_RECOVERY_ACTION = 230 + +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_0 = 235 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_1 = 236 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_2 = 237 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_3 = 238 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_4 = 239 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_5 = 240 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_6 = 241 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_7 = 242 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_8 = 243 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_9 = 244 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_10 = 245 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_11 = 246 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_12 = 247 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_13 = 248 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_14 = 249 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_15 = 250 +NVML_FI_PWR_SMOOTHING_ENABLED = 251 # Enablement (0/DISABLED or 1/ENABLED) +NVML_FI_PWR_SMOOTHING_PRIV_LVL = 252 # Current privilege level +NVML_FI_PWR_SMOOTHING_IMM_RAMP_DOWN_ENABLED = 253 # Immediate ramp down enablement (0/DISABLED or 1/ENABLED) +NVML_FI_PWR_SMOOTHING_APPLIED_TMP_CEIL = 254 # Applied TMP ceiling value +NVML_FI_PWR_SMOOTHING_APPLIED_TMP_FLOOR = 255 # Applied TMP floor value +NVML_FI_PWR_SMOOTHING_MAX_PERCENT_TMP_FLOOR_SETTING = 256 # Max % TMP Floor value +NVML_FI_PWR_SMOOTHING_MIN_PERCENT_TMP_FLOOR_SETTING = 257 # Min % TMP Floor value +NVML_FI_PWR_SMOOTHING_HW_CIRCUITRY_PERCENT_LIFETIME_REMAINING = 258 # HW Circuitry % lifetime remaining +NVML_FI_PWR_SMOOTHING_MAX_NUM_PRESET_PROFILES = 259 # Max number of preset profiles +NVML_FI_PWR_SMOOTHING_PROFILE_PERCENT_TMP_FLOOR = 260 # % TMP floor for a given profile +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_UP_RATE = 261 # Ramp up rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_RATE = 262 # Ramp down rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_HYST_VAL = 263 # Ramp down hysteresis value in ms for a given profile +NVML_FI_PWR_SMOOTHING_ACTIVE_PRESET_PROFILE = 264 # Active preset profile number +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_PERCENT_TMP_FLOOR = 265 # % TMP floor for a given profile +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_UP_RATE = 266 # Ramp up rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_RATE = 267 # Ramp down rate in mW/s for a given profile +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_HYST_VAL = 268 # Ramp down hysteresis value in ms for a given profile + +NVML_FI_MAX = 269 # One greater than the largest field ID defined above + +# NVML_FI_DEV_NVLINK_GET_STATE state enums +NVML_NVLINK_STATE_INACTIVE = 0x0 +NVML_NVLINK_STATE_ACTIVE = 0x1 +NVML_NVLINK_STATE_SLEEP = 0x2 + +NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_100US = 0 # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS 
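The state enums just above pair with `NVML_FI_DEV_NVLINK_GET_STATE` (the remaining low-power-threshold unit constant follows this example). The sketch below assumes, as in upstream nvidia-ml-py, that `nvmlDeviceGetFieldValues` accepts `(fieldId, scopeId)` tuples and that `scopeId` selects the link index for per-link NVLink fields; verify this against the NVML documentation for the driver in use before relying on it.

```python
from pynvml import (  # illustrative import path
    NVML_FI_DEV_NVLINK_GET_STATE, NVML_NVLINK_STATE_ACTIVE,
    NVML_NVLINK_STATE_INACTIVE, NVML_NVLINK_STATE_SLEEP, NVML_SUCCESS,
    nvmlDeviceGetFieldValues, nvmlDeviceGetHandleByIndex, nvmlInit, nvmlShutdown,
)

STATE_NAMES = {
    NVML_NVLINK_STATE_INACTIVE: "inactive",
    NVML_NVLINK_STATE_ACTIVE: "active",
    NVML_NVLINK_STATE_SLEEP: "sleep",
}

nvmlInit()
try:
    handle = nvmlDeviceGetHandleByIndex(0)
    # One query per link; scopeId is assumed to be the link index here.
    queries = [(NVML_FI_DEV_NVLINK_GET_STATE, link) for link in range(4)]
    for link, fv in enumerate(nvmlDeviceGetFieldValues(handle, queries)):
        if fv.nvmlReturn != NVML_SUCCESS:
            continue  # link absent or field unsupported on this GPU
        # The state is assumed to come back as an unsigned int (uiVal).
        print("NVLink %d: %s" % (link, STATE_NAMES.get(fv.value.uiVal, "unknown")))
finally:
    nvmlShutdown()
```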
+NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_50US = 1 # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS + +## Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode +NVML_GPU_VIRTUALIZATION_MODE_NONE = 0 # Represents Bare Metal GPU +NVML_GPU_VIRTUALIZATION_MODE_PASSTHROUGH = 1 # Device is associated with GPU-Passthorugh +NVML_GPU_VIRTUALIZATION_MODE_VGPU = 2 # Device is associated with vGPU inside virtual machine. +NVML_GPU_VIRTUALIZATION_MODE_HOST_VGPU = 3 # Device is associated with VGX hypervisor in vGPU mode +NVML_GPU_VIRTUALIZATION_MODE_HOST_VSGA = 4 # Device is associated with VGX hypervisor in vSGA mode + +## Lib loading ## +nvmlLib = None +libLoadLock = threading.Lock() +_nvmlLib_refcount = 0 # Incremented on each nvmlInit and decremented on nvmlShutdown + +## vGPU Management +_nvmlVgpuTypeId_t = c_uint +_nvmlVgpuInstance_t = c_uint + +_nvmlVgpuVmIdType_t = c_uint +NVML_VGPU_VM_ID_DOMAIN_ID = 0 +NVML_VGPU_VM_ID_UUID = 1 + +_nvmlGridLicenseFeatureCode_t = c_uint +NVML_GRID_LICENSE_FEATURE_CODE_UNKNOWN = 0 +NVML_GRID_LICENSE_FEATURE_CODE_VGPU = 1 +NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX = 2 +NVML_GRID_LICENSE_FEATURE_CODE_VWORKSTATION = 2 # deprecated, use NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX. +NVML_GRID_LICENSE_FEATURE_CODE_GAMING = 3 +NVML_GRID_LICENSE_FEATURE_CODE_COMPUTE = 4 + +_nvmlGridLicenseExpiryStatus_t = c_uint8 +NVML_GRID_LICENSE_EXPIRY_NOT_AVAILABLE = 0, # Expiry information not available +NVML_GRID_LICENSE_EXPIRY_INVALID = 1, # Invalid expiry or error fetching expiry +NVML_GRID_LICENSE_EXPIRY_VALID = 2, # Valid expiry +NVML_GRID_LICENSE_EXPIRY_NOT_APPLICABLE = 3, # Expiry not applicable +NVML_GRID_LICENSE_EXPIRY_PERMANENT = 4, # Permanent expiry + +_nvmlVgpuCapability_t = c_uint +NVML_VGPU_CAP_NVLINK_P2P = 0 # vGPU P2P over NVLink is supported +NVML_VGPU_CAP_GPUDIRECT = 1 # GPUDirect capability is supported +NVML_VGPU_CAP_MULTI_VGPU_EXCLUSIVE = 2 # vGPU profile cannot be mixed with other vGPU profiles in same VM +NVML_VGPU_CAP_EXCLUSIVE_TYPE = 3 # vGPU profile cannot run on a GPU alongside other profiles of different type +NVML_VGPU_CAP_EXCLUSIVE_SIZE = 4 # vGPU profile cannot run on a GPU alongside other profiles of different size +NVML_VGPU_CAP_COUNT = 5 + +_nvmlVgpuDriverCapability_t = c_uint +NVML_VGPU_DRIVER_CAP_HETEROGENEOUS_MULTI_VGPU = 0 # Supports mixing of different vGPU profiles within one guest VM +NVML_VGPU_DRIVER_CAP_WARM_UPDATE = 1 # Supports FSR and warm update of vGPU host driver without terminating the running guest VM +NVML_VGPU_DRIVER_CAP_COUNT = 2 + +_nvmlDeviceVgpuCapability_t = c_uint +NVML_DEVICE_VGPU_CAP_FRACTIONAL_MULTI_VGPU = 0 # Query whether the fractional vGPU profiles on this GPU can be used in multi-vGPU configurations +NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_PROFILES = 1 # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing types +NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_SIZES = 2 # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing framebuffer sizes +NVML_DEVICE_VGPU_CAP_READ_DEVICE_BUFFER_BW = 3 # Query the GPU's read_device_buffer expected bandwidth capacity in megabytes per second +NVML_DEVICE_VGPU_CAP_WRITE_DEVICE_BUFFER_BW = 4 # Query the GPU's write_device_buffer expected bandwidth capacity in megabytes per second +NVML_DEVICE_VGPU_CAP_DEVICE_STREAMING = 5 # Query whether the vGPU profiles on the GPU supports migration data streaming +NVML_DEVICE_VGPU_CAP_MINI_QUARTER_GPU = 6 # Set/Get support of 
mini-quarter vGPU profiles +NVML_DEVICE_VGPU_CAP_COMPUTE_MEDIA_ENGINE_GPU = 7 # Set/Get support for compute media engine vGPU profiles +NVML_DEVICE_VGPU_CAP_WARM_UPDATE = 8 # Query whether the GPU supports FSR and warm update +NVML_DEVICE_VGPU_CAP_HOMOGENEOUS_PLACEMENTS = 9 # Query whether the GPU supports reporting of placements of timesliced vGPU profiles with identical framebuffer sizes +NVML_DEVICE_VGPU_CAP_COUNT = 10 + +_nvmlVgpuGuestInfoState_t = c_uint +NVML_VGPU_INSTANCE_GUEST_INFO_STATE_UNINITIALIZED = 0 +NVML_VGPU_INSTANCE_GUEST_INFO_STATE_INITIALIZED = 1 + +_nvmlVgpuVmCompatibility_t = c_uint +NVML_VGPU_VM_COMPATIBILITY_NONE = 0x0 +NVML_VGPU_VM_COMPATIBILITY_COLD = 0x1 +NVML_VGPU_VM_COMPATIBILITY_HIBERNATE = 0x2 +NVML_VGPU_VM_COMPATIBILITY_SLEEP = 0x4 +NVML_VGPU_VM_COMPATIBILITY_LIVE = 0x8 + +_nvmlVgpuPgpuCompatibilityLimitCode_t = c_uint +NVML_VGPU_COMPATIBILITY_LIMIT_NONE = 0x0 +NVML_VGPU_COMPATIBILITY_LIMIT_HOST_DRIVER = 0x1 +NVML_VGPU_COMPATIBILITY_LIMIT_GUEST_DRIVER = 0x2 +NVML_VGPU_COMPATIBILITY_LIMIT_GPU = 0x4 +NVML_VGPU_COMPATIBILITY_LIMIT_OTHER = 0x80000000 + +_nvmlHostVgpuMode_t = c_uint +NVML_HOST_VGPU_MODE_NON_SRIOV = 0 +NVML_HOST_VGPU_MODE_SRIOV = 1 + +_nvmlConfComputeGpusReadyState_t = c_uint +NVML_CC_ACCEPTING_CLIENT_REQUESTS_FALSE = 0 +NVML_CC_ACCEPTING_CLIENT_REQUESTS_TRUE = 1 + +_nvmlConfComputeGpuCaps_t = c_uint +NVML_CC_SYSTEM_GPUS_CC_NOT_CAPABLE = 0 +NVML_CC_SYSTEM_GPUS_CC_CAPABLE = 1 + +_nvmlConfComputeCpuCaps_t = c_uint +NVML_CC_SYSTEM_CPU_CAPS_NONE = 0 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV = 1 +NVML_CC_SYSTEM_CPU_CAPS_INTEL_TDX = 2 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV_SNP = 3 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SNP_VTOM = 4 + +_nvmlConfComputeDevToolsMode_t = c_uint +NVML_CC_SYSTEM_DEVTOOLS_MODE_OFF = 0 +NVML_CC_SYSTEM_DEVTOOLS_MODE_ON = 1 + +NVML_CC_SYSTEM_MULTIGPU_NONE = 0 +NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE = 1 + +NVML_CC_SYSTEM_ENVIRONMENT_UNAVAILABLE = 0 +NVML_CC_SYSTEM_ENVIRONMENT_SIM = 1 +NVML_CC_SYSTEM_ENVIRONMENT_PROD = 2 + +_nvmlConfComputeCcFeature_t = c_uint +NVML_CC_SYSTEM_FEATURE_DISABLED = 0 +NVML_CC_SYSTEM_FEATURE_ENABLED = 1 + +_nvmlConfComputeCcKeyRotationThreshAttackerAdv_t = c_uint +NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MIN = 50 +NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MAX = 65 + +# GSP firmware +NVML_GSP_FIRMWARE_VERSION_BUF_SIZE = 0x40 + +class NVMLLibraryMismatchError(Exception): + pass + +## Error Checking ## +class NVMLError(Exception): + _valClassMapping = dict() + # List of currently known error codes + _errcode_to_string = { + NVML_ERROR_UNINITIALIZED: "Uninitialized", + NVML_ERROR_INVALID_ARGUMENT: "Invalid Argument", + NVML_ERROR_NOT_SUPPORTED: "Not Supported", + NVML_ERROR_NO_PERMISSION: "Insufficient Permissions", + NVML_ERROR_ALREADY_INITIALIZED: "Already Initialized", + NVML_ERROR_NOT_FOUND: "Not Found", + NVML_ERROR_INSUFFICIENT_SIZE: "Insufficient Size", + NVML_ERROR_INSUFFICIENT_POWER: "Insufficient External Power", + NVML_ERROR_DRIVER_NOT_LOADED: "Driver Not Loaded", + NVML_ERROR_TIMEOUT: "Timeout", + NVML_ERROR_IRQ_ISSUE: "Interrupt Request Issue", + NVML_ERROR_LIBRARY_NOT_FOUND: "NVML Shared Library Not Found", + NVML_ERROR_FUNCTION_NOT_FOUND: "Function Not Found", + NVML_ERROR_CORRUPTED_INFOROM: "Corrupted infoROM", + NVML_ERROR_GPU_IS_LOST: "GPU is lost", + NVML_ERROR_RESET_REQUIRED: "GPU requires restart", + NVML_ERROR_OPERATING_SYSTEM: "The operating system has blocked the request.", + NVML_ERROR_LIB_RM_VERSION_MISMATCH: "RM has detected an NVML/RM version mismatch.", + NVML_ERROR_MEMORY: "Insufficient 
Memory", + NVML_ERROR_UNKNOWN: "Unknown Error", + } + def __new__(typ, value): + ''' + Maps value to a proper subclass of NVMLError. + See _extractNVMLErrorsAsClasses function for more details + ''' + if typ == NVMLError: + typ = NVMLError._valClassMapping.get(value, typ) + obj = Exception.__new__(typ) + obj.value = value + return obj + def __str__(self): + try: + if self.value not in NVMLError._errcode_to_string: + NVMLError._errcode_to_string[self.value] = str(nvmlErrorString(self.value)) + return NVMLError._errcode_to_string[self.value] + except NVMLError: + return "NVML Error with code %d" % self.value + def __eq__(self, other): + return self.value == other.value + +def nvmlExceptionClass(nvmlErrorCode): + if nvmlErrorCode not in NVMLError._valClassMapping: + raise ValueError('nvmlErrorCode %s is not valid' % nvmlErrorCode) + return NVMLError._valClassMapping[nvmlErrorCode] + +def _extractNVMLErrorsAsClasses(): + ''' + Generates a hierarchy of classes on top of NVMLError class. + + Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate + exceptions more easily. + + NVMLError is a parent class. Each NVML_ERROR_* gets it's own subclass. + e.g. NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized + ''' + this_module = sys.modules[__name__] + nvmlErrorsNames = [x for x in dir(this_module) if x.startswith("NVML_ERROR_")] + for err_name in nvmlErrorsNames: + # e.g. Turn NVML_ERROR_ALREADY_INITIALIZED into NVMLError_AlreadyInitialized + class_name = "NVMLError_" + string.capwords(err_name.replace("NVML_ERROR_", ""), "_").replace("_", "") + err_val = getattr(this_module, err_name) + def gen_new(val): + def new(typ): + obj = NVMLError.__new__(typ, val) + return obj + return new + new_error_class = type(class_name, (NVMLError,), {'__new__': gen_new(err_val)}) + new_error_class.__module__ = __name__ + setattr(this_module, class_name, new_error_class) + NVMLError._valClassMapping[err_val] = new_error_class +_extractNVMLErrorsAsClasses() + +def _nvmlCheckReturn(ret): + if (ret != NVML_SUCCESS): + raise NVMLError(ret) + return ret + +## Function access ## +_nvmlGetFunctionPointer_cache = dict() # function pointers are cached to prevent unnecessary libLoadLock locking +def _nvmlGetFunctionPointer(name): + global nvmlLib + + if name in _nvmlGetFunctionPointer_cache: + return _nvmlGetFunctionPointer_cache[name] + + libLoadLock.acquire() + try: + # ensure library was loaded + if (nvmlLib == None): + raise NVMLError(NVML_ERROR_UNINITIALIZED) + try: + _nvmlGetFunctionPointer_cache[name] = getattr(nvmlLib, name) + return _nvmlGetFunctionPointer_cache[name] + except AttributeError: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + +## Alternative object +# Allows the object to be printed +# Allows mismatched types to be assigned +# - like None when the Structure variant requires c_uint +class nvmlFriendlyObject(object): + def __init__(self, dictionary): + for x in dictionary: + setattr(self, x, dictionary[x]) + def __str__(self): + return self.__dict__.__str__() + +def nvmlStructToFriendlyObject(struct): + d = {} + for x in struct._fields_: + key = x[0] + value = getattr(struct, key) + # only need to convert from bytes if bytes, no need to check python version. 
+ d[key] = value.decode() if isinstance(value, bytes) else value + obj = nvmlFriendlyObject(d) + return obj + +# pack the object so it can be passed to the NVML library +def nvmlFriendlyObjectToStruct(obj, model): + for x in model._fields_: + key = x[0] + value = obj.__dict__[key] + # any c_char_p in python3 needs to be bytes, default encoding works fine. + if sys.version_info >= (3,): + setattr(model, key, value.encode()) + else: + setattr(model, key, value) + return model + +## Unit structures +class struct_c_nvmlUnit_t(Structure): + pass # opaque handle +c_nvmlUnit_t = POINTER(struct_c_nvmlUnit_t) + +class _PrintableStructure(Structure): + """ + Abstract class that produces nicer __str__ output than ctypes.Structure. + e.g. instead of: + >>> print str(obj) + + this class will print + class_name(field_name: formatted_value, field_name: formatted_value) + + _fmt_ dictionary of -> + e.g. class that has _field_ 'hex_value', c_uint could be formatted with + _fmt_ = {"hex_value" : "%08X"} + to produce nicer output. + Default fomratting string for all fields can be set with key "" like: + _fmt_ = {"" : "%d MHz"} # e.g all values are numbers in MHz. + If not set it's assumed to be just "%s" + + Exact format of returned str from this class is subject to change in the future. + """ + _fmt_ = {} + def __str__(self): + result = [] + for x in self._fields_: + key = x[0] + value = getattr(self, key) + fmt = "%s" + if key in self._fmt_: + fmt = self._fmt_[key] + elif "" in self._fmt_: + fmt = self._fmt_[""] + result.append(("%s: " + fmt) % (key, value)) + return self.__class__.__name__ + "(" + ", ".join(result) + ")" + + def __getattribute__(self, name): + res = super(_PrintableStructure, self).__getattribute__(name) + # need to convert bytes to unicode for python3 don't need to for python2 + # Python 2 strings are of both str and bytes + # Python 3 strings are not of type bytes + # ctypes should convert everything to the correct values otherwise + if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + def __setattr__(self, name, value): + if isinstance(value, str): + # encoding a python2 string returns the same value, since python2 strings are bytes already + # bytes passed in python3 will be ignored. 
+ value = value.encode() + super(_PrintableStructure, self).__setattr__(name, value) + +class c_nvmlUnitInfo_t(_PrintableStructure): + _fields_ = [ + ('name', c_char * 96), + ('id', c_char * 96), + ('serial', c_char * 96), + ('firmwareVersion', c_char * 96), + ] + +class c_nvmlC2cModeInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('isC2cEnabled', c_uint) + ] + +nvmlC2cModeInfo_v1 = 0x1000008; + +class c_nvmlLedState_t(_PrintableStructure): + _fields_ = [ + ('cause', c_char * 256), + ('color', _nvmlLedColor_t), + ] + +class c_nvmlPSUInfo_t(_PrintableStructure): + _fields_ = [ + ('state', c_char * 256), + ('current', c_uint), + ('voltage', c_uint), + ('power', c_uint), + ] + +class c_nvmlUnitFanInfo_t(_PrintableStructure): + _fields_ = [ + ('speed', c_uint), + ('state', _nvmlFanState_t), + ] + +class c_nvmlUnitFanSpeeds_t(_PrintableStructure): + _fields_ = [ + ('fans', c_nvmlUnitFanInfo_t * 24), + ('count', c_uint) + ] + +## Device structures +class struct_c_nvmlDevice_t(Structure): + pass # opaque handle +c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t) + +class nvmlPciInfoExt_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('domain', c_uint), + ('bus', c_uint), + ('device', c_uint), + ('pciDeviceId', c_uint), + ('pciSubSystemId', c_uint), + ('baseClass', c_uint), + ('subClass', c_uint), + ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE), + ] + _fmt_ = { + 'version' : "0x%04X", + 'domain' : "0x%04X", + 'bus' : "0x%02X", + 'device' : "0x%02X", + 'pciDeviceId' : "0x%08X", + 'pciSubSystemId' : "0x%08X", + 'baseClass' : "0x%01X", + 'subClass' : "0x%01X", + } + +nvmlPciInfoExt_v1 = 0x1000040 + +# Legacy pciInfo used for _v1 and _v2 +class nvmlPciInfo_v2_t(_PrintableStructure): + _fields_ = [ + ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE), + ('domain', c_uint), + ('bus', c_uint), + ('device', c_uint), + ('pciDeviceId', c_uint), + + # Added in 2.285 + ('pciSubSystemId', c_uint), + ('reserved0', c_uint), + ('reserved1', c_uint), + ('reserved2', c_uint), + ('reserved3', c_uint), + ] + _fmt_ = { + 'domain' : "0x%04X", + 'bus' : "0x%02X", + 'device' : "0x%02X", + 'pciDeviceId' : "0x%08X", + 'pciSubSystemId' : "0x%08X", + } + +class nvmlPciInfo_t(_PrintableStructure): + _fields_ = [ + # Moved to the new busId location below + ('busIdLegacy', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE), + ('domain', c_uint), + ('bus', c_uint), + ('device', c_uint), + ('pciDeviceId', c_uint), + + # Added in 2.285 + ('pciSubSystemId', c_uint), + # New busId replaced the long deprecated and reserved fields with a + # field of the same size in 9.0 + ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE), + ] + _fmt_ = { + 'domain' : "0x%08X", + 'bus' : "0x%02X", + 'device' : "0x%02X", + 'pciDeviceId' : "0x%08X", + 'pciSubSystemId' : "0x%08X", + } + +class c_nvmlSystemDriverBranchInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ("branch", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ] + +SystemDriverBranchInfo_v1 = 0x1000054 + +class c_nvmlExcludedDeviceInfo_t(_PrintableStructure): + _fields_ = [ + ('pci', nvmlPciInfo_t), + ('uuid', c_char * NVML_DEVICE_UUID_BUFFER_SIZE) + ] + +class nvmlNvLinkUtilizationControl_t(_PrintableStructure): + _fields_ = [ + ('units', _nvmlNvLinkUtilizationCountUnits_t), + ('pktfilter', _nvmlNvLinkUtilizationCountPktTypes_t), + ] + +class c_nvmlMemory_t(_PrintableStructure): + _fields_ = [ + ('total', c_ulonglong), + ('free', c_ulonglong), + ('used', c_ulonglong), + ] + _fmt_ = {'': "%d B"} + +class 
c_nvmlMemory_v2_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('total', c_ulonglong), + ('reserved', c_ulonglong), + ('free', c_ulonglong), + ('used', c_ulonglong), + ] + _fmt_ = {'': "%d B"} + +nvmlMemory_v2 = 0x02000028 + +class c_nvmlBAR1Memory_t(_PrintableStructure): + _fields_ = [ + ('bar1Total', c_ulonglong), + ('bar1Free', c_ulonglong), + ('bar1Used', c_ulonglong), + ] + _fmt_ = {'': "%d B"} + +class nvmlClkMonFaultInfo_t(Structure): + _fields_ = [("clkApiDomain", c_uint), + ("clkDomainFaultMask", c_uint) + ] + +MAX_CLK_DOMAINS = 32 + +class nvmlClkMonStatus_t(Structure): + _fields_ = [("bGlobalStatus", c_uint), + ("clkMonListSize", c_uint), + ("clkMonList", nvmlClkMonFaultInfo_t * MAX_CLK_DOMAINS) + ] + +# On Windows with the WDDM driver, usedGpuMemory is reported as None +# Code that processes this structure should check for None, I.E. +# +# if (info.usedGpuMemory == None): +# # TODO handle the error +# pass +# else: +# print("Using %d MiB of memory" % (info.usedGpuMemory / 1024 / 1024)) +# endif +# +# See NVML documentation for more information +class c_nvmlProcessInfo_v2_t(_PrintableStructure): + _fields_ = [ + ('pid', c_uint), + ('usedGpuMemory', c_ulonglong), + ('gpuInstanceId', c_uint), + ('computeInstanceId', c_uint), + ] + _fmt_ = {'usedGpuMemory': "%d B"} + +c_nvmlProcessInfo_v3_t = c_nvmlProcessInfo_v2_t + +c_nvmlProcessInfo_t = c_nvmlProcessInfo_v3_t + +_nvmlProcessMode_t = c_uint +NVML_PROCESS_MODE_COMPUTE = 0 +NVML_PROCESS_MODE_GRAPHICS = 1 +NVML_PROCESS_MODE_MPS = 2 + +class c_nvmlProcessDetail_v1_t(Structure): + _fields_ = [ + ('pid', c_uint), + ('usedGpuMemory', c_ulonglong), + ('gpuInstanceId', c_uint), + ('computeInstanceId', c_uint), + ('usedGpuCcProtectedMemory', c_ulonglong), + ] + +class c_nvmlProcessDetailList_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('mode', _nvmlProcessMode_t), + ('numProcArrayEntries', c_uint), + ('procArray', POINTER(c_nvmlProcessDetail_v1_t)), + ] + _fmt_ = {'numProcArrayEntries': "%d B"} + +c_nvmlProcessDetailList_t = c_nvmlProcessDetailList_v1_t + +nvmlProcessDetailList_v1 = 0x1000018 + +class c_nvmlBridgeChipInfo_t(_PrintableStructure): + _fields_ = [ + ('type', _nvmlBridgeChipType_t), + ('fwVersion', c_uint), + ] + +class c_nvmlBridgeChipHierarchy_t(_PrintableStructure): + _fields_ = [ + ('bridgeCount', c_uint), + ('bridgeChipInfo', c_nvmlBridgeChipInfo_t * 128), + ] + +class c_nvmlEccErrorCounts_t(_PrintableStructure): + _fields_ = [ + ('l1Cache', c_ulonglong), + ('l2Cache', c_ulonglong), + ('deviceMemory', c_ulonglong), + ('registerFile', c_ulonglong), + ] + +class c_nvmlUtilization_t(_PrintableStructure): + _fields_ = [ + ('gpu', c_uint), + ('memory', c_uint), + ] + _fmt_ = {'': "%d %%"} + +# Added in 2.285 +class c_nvmlHwbcEntry_t(_PrintableStructure): + _fields_ = [ + ('hwbcId', c_uint), + ('firmwareVersion', c_char * 32), + ] + +class c_nvmlValue_t(Union): + _fields_ = [ + ('dVal', c_double), + ('uiVal', c_uint), + ('ulVal', c_ulong), + ('ullVal', c_ulonglong), + ('sllVal', c_longlong), + ('siVal', c_int), + ('usVal', c_ushort), + ] + +class c_nvmlSample_t(_PrintableStructure): + _fields_ = [ + ('timeStamp', c_ulonglong), + ('sampleValue', c_nvmlValue_t), + ] + +class c_nvmlViolationTime_t(_PrintableStructure): + _fields_ = [ + ('referenceTime', c_ulonglong), + ('violationTime', c_ulonglong), + ] + +class c_nvmlFieldValue_t(_PrintableStructure): + _fields_ = [ + ('fieldId', c_uint32), + ('scopeId', c_uint32), + ('timestamp', c_int64), + ('latencyUsec', c_int64), + ('valueType', 
_nvmlValueType_t), + ('nvmlReturn', _nvmlReturn_t), + ('value', c_nvmlValue_t) + ] + +NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES = 23 + +nvmlNvlinkSupportedBwModes_v1 = 0x100001c +class c_nvmlNvlinkSupportedBwModes_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bwModes', c_uint8 * NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES), + ('totalBwModes', c_uint8) + ] + + def __init__(self): + super(c_nvmlNvlinkSupportedBwModes_v1_t, self).__init__(version=nvmlNvlinkSupportedBwModes_v1) + +nvmlNvlinkGetBwMode_v1 = 0x100000c +class c_nvmlNvlinkGetBwMode_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bIsBest', c_uint), + ('bwMode', c_uint8) + ] + + def __init__(self): + super(c_nvmlNvlinkGetBwMode_v1_t, self).__init__(version=nvmlNvlinkGetBwMode_v1) + +nvmlNvlinkSetBwMode_v1 = 0x100000c +class c_nvmlNvlinkSetBwMode_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bSetBest', c_uint), + ('bwMode', c_uint8) + ] + + def __init__(self): + super(c_nvmlNvlinkSetBwMode_v1_t, self).__init__(version=nvmlNvlinkSetBwMode_v1) + +class c_nvmlVgpuHeterogeneousMode_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('mode', c_uint), + ] + +VgpuHeterogeneousMode_v1 = 0x1000008 + +class c_nvmlVgpuPlacementId_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('placementId', c_uint), + ] + +VgpuPlacementId_v1 = 0x1000008 + +class c_nvmlVgpuPlacementList_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('count', c_uint), + ('placementSize', c_uint), + ('placementIds', POINTER(c_uint)), + ] + +VgpuPlacementList_v1 = 0x1000018 + +NVML_VGPU_PGPU_HETEROGENEOUS_MODE = 0 +NVML_VGPU_PGPU_HOMOGENEOUS_MODE = 1 + +class c_nvmlVgpuPlacementList_v2_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('placementSize', c_uint), + ('count', c_uint), + ('placementIds', POINTER(c_uint)), + ('mode', c_uint), + ] + +VgpuPlacementList_v2 = 0x2000020 + +class c_nvmlVgpuTypeBar1Info_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('bar1Size', c_ulonglong), + ] + +VgpuTypeBar1Info_v1 = 0x1000010 + +class c_nvmlVgpuInstanceUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ('vgpuInstance', _nvmlVgpuInstance_t), + ('timeStamp', c_ulonglong), + ('smUtil', c_nvmlValue_t), + ('memUtil', c_nvmlValue_t), + ('encUtil', c_nvmlValue_t), + ('decUtil', c_nvmlValue_t), + ] + +class c_nvmlVgpuInstanceUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('timeStamp', c_ulonglong), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('smUtil', c_nvmlValue_t), + ('memUtil', c_nvmlValue_t), + ('encUtil', c_nvmlValue_t), + ('decUtil', c_nvmlValue_t), + ('jpgUtil', c_nvmlValue_t), + ('ofaUtil', c_nvmlValue_t), + ] + +class c_nvmlVgpuInstancesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('sampleValType', _nvmlValueType_t), + ('vgpuInstanceCount', c_uint), + ('lastSeenTimeStamp', c_ulonglong), + ('vgpuUtilArray', POINTER(c_nvmlVgpuInstanceUtilizationInfo_v1_t)), + ] + +VgpuInstancesUtilizationInfo_v1 = 0x01000020 + +class c_nvmlVgpuProcessUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ('vgpuInstance', _nvmlVgpuInstance_t), + ('pid', c_uint), + ('processName', c_char * NVML_VGPU_NAME_BUFFER_SIZE), + ('timeStamp', c_ulonglong), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ] + +class c_nvmlVgpuProcessUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('processName', c_char * NVML_VGPU_NAME_BUFFER_SIZE), + ('timeStamp', 
c_ulonglong), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('pid', c_uint), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ('jpgUtil', c_uint), + ('ofaUtil', c_uint), + ] + +class c_nvmlVgpuProcessesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('vgpuProcessCount', c_uint), + ('lastSeenTimeStamp', c_ulonglong), + ('vgpuProcUtilArray', POINTER(c_nvmlVgpuProcessUtilizationInfo_v1_t)), + ] + +VgpuProcessesUtilizationInfo_v1 = 0x01000018 + +class nvmlVgpuRuntimeState_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('size', c_ulonglong), + ] + +VgpuRuntimeState_v1 = 0x1000010 + +class c_nvmlVgpuLicenseExpiry_t(_PrintableStructure): + _fields_ = [ + ('year', c_uint32), + ('month', c_uint16), + ('day', c_uint16), + ('hour', c_uint16), + ('min', c_uint16), + ('sec', c_uint16), + ('status', c_uint8), + ] + +NVML_GRID_LICENSE_STATE_UNKNOWN = 0 +NVML_GRID_LICENSE_STATE_UNINITIALIZED = 1 +NVML_GRID_LICENSE_STATE_UNLICENSED_UNRESTRICTED = 2 +NVML_GRID_LICENSE_STATE_UNLICENSED_RESTRICTED = 3 +NVML_GRID_LICENSE_STATE_UNLICENSED = 4 +NVML_GRID_LICENSE_STATE_LICENSED = 5 + +class c_nvmlVgpuLicenseInfo_t(_PrintableStructure): + _fields_ = [ + ('isLicensed', c_uint8), + ('licenseExpiry', c_nvmlVgpuLicenseExpiry_t), + ('currentState', c_uint), + ] + +class c_nvmlEncoderSession_t(_PrintableStructure): + _fields_ = [ + ('sessionId', c_uint), + ('pid', c_uint), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('codecType', c_uint), + ('hResolution', c_uint), + ('vResolution', c_uint), + ('averageFps', c_uint), + ('encodeLatency', c_uint), + ] + +class c_nvmlProcessUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ('pid', c_uint), + ('timeStamp', c_ulonglong), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ] + +class c_nvmlProcessUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('timeStamp', c_ulonglong), + ('pid', c_uint), + ('smUtil', c_uint), + ('memUtil', c_uint), + ('encUtil', c_uint), + ('decUtil', c_uint), + ('jpgUtil', c_uint), + ('ofaUtil', c_uint), + ] + +class c_nvmlProcessesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('processSamplesCount', c_uint), + ('lastSeenTimeStamp', c_ulonglong), + ('procUtilArray', POINTER(c_nvmlProcessUtilizationInfo_v1_t)), + ] + +ProcessesUtilizationInfo_v1 = 0x01000018 + +class c_nvmlGridLicenseExpiry_t(_PrintableStructure): + _fields_ = [ + ('year', c_uint32), + ('month', c_uint16), + ('day', c_uint16), + ('hour', c_uint16), + ('min', c_uint16), + ('sec', c_uint16), + ('status', c_uint8), + ] + +class c_nvmlGridLicensableFeature_v4_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('productName', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('featureEnabled', c_uint), + ('licenseExpiry', c_nvmlGridLicenseExpiry_t), + ] + +class c_nvmlGridLicensableFeatures_v4_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_v4_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlGridLicensableFeature_v3_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('productName', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + 
('featureEnabled', c_uint), + ] + +class c_nvmlGridLicensableFeatures_v3_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_v3_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlGridLicensableFeature_v2_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ('productName', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + +class c_nvmlGridLicensableFeatures_v2_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_v2_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlGridLicensableFeature_t(_PrintableStructure): + _fields_ = [ + ('featureCode', _nvmlGridLicenseFeatureCode_t), + ('featureState', c_uint), + ('licenseInfo', c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + +class c_nvmlGridLicensableFeatures_t(_PrintableStructure): + _fields_ = [ + ('isGridLicenseSupported', c_int), + ('licensableFeaturesCount', c_uint), + ('gridLicensableFeatures', c_nvmlGridLicensableFeature_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT), + ] + +class c_nvmlMarginTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('marginTemperature', c_int), + ] + +nvmlMarginTemperature_v1 = 0x1000008 + +## Event structures +class struct_c_nvmlEventSet_t(Structure): + pass # opaque handle +c_nvmlEventSet_t = POINTER(struct_c_nvmlEventSet_t) + +nvmlEventTypeSingleBitEccError = 0x0000000000000001 +nvmlEventTypeDoubleBitEccError = 0x0000000000000002 +nvmlEventTypePState = 0x0000000000000004 +nvmlEventTypeXidCriticalError = 0x0000000000000008 +nvmlEventTypeClock = 0x0000000000000010 +nvmlEventTypePowerSourceChange = 0x0000000000000080 +nvmlEventMigConfigChange = 0x0000000000000100 +nvmlEventTypeSingleBitEccErrorStorm = 0x0000000000000200 +nvmlEventTypeDramRetirementEvent = 0x0000000000000400 +nvmlEventTypeDramRetirementFailure = 0x0000000000000800 +nvmlEventTypeNonFatalPoisonError = 0x0000000000001000 +nvmlEventTypeFatalPoisonError = 0x0000000000002000 +nvmlEventTypeGpuUnavailableError = 0x0000000000004000 +nvmlEventTypeGpuRecoveryAction = 0x0000000000008000 +nvmlEventTypeNone = 0x0000000000000000 +nvmlEventTypeAll = ( + nvmlEventTypeNone + | nvmlEventTypeSingleBitEccError + | nvmlEventTypeDoubleBitEccError + | nvmlEventTypePState + | nvmlEventTypeClock + | nvmlEventTypePowerSourceChange + | nvmlEventTypeXidCriticalError + | nvmlEventMigConfigChange + | nvmlEventTypeSingleBitEccErrorStorm + | nvmlEventTypeDramRetirementEvent + | nvmlEventTypeDramRetirementFailure + | nvmlEventTypeNonFatalPoisonError + | nvmlEventTypeFatalPoisonError + | nvmlEventTypeGpuUnavailableError + | nvmlEventTypeGpuRecoveryAction + ) + +## Clock Event Reasons defines +nvmlClocksEventReasonGpuIdle = 0x0000000000000001 +nvmlClocksEventReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksEventReasonUserDefinedClocks = nvmlClocksEventReasonApplicationsClocksSetting # deprecated, use nvmlClocksEventReasonApplicationsClocksSetting +nvmlClocksEventReasonSwPowerCap = 0x0000000000000004 +nvmlClocksEventReasonHwSlowdown = 0x0000000000000008 +nvmlClocksEventReasonSyncBoost = 0x0000000000000010 +nvmlClocksEventReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksEventReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksEventReasonHwPowerBrakeSlowdown = 
0x0000000000000080 +nvmlClocksEventReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksEventReasonNone = 0x0000000000000000 +nvmlClocksEventReasonAll = ( + nvmlClocksEventReasonNone | + nvmlClocksEventReasonGpuIdle | + nvmlClocksEventReasonApplicationsClocksSetting | + nvmlClocksEventReasonSwPowerCap | + nvmlClocksEventReasonHwSlowdown | + nvmlClocksEventReasonSyncBoost | + nvmlClocksEventReasonSwThermalSlowdown | + nvmlClocksEventReasonHwThermalSlowdown | + nvmlClocksEventReasonHwPowerBrakeSlowdown | + nvmlClocksEventReasonDisplayClockSetting + ) + +## Following have been deprecated +nvmlClocksThrottleReasonGpuIdle = 0x0000000000000001 +nvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting +nvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004 +nvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008 +nvmlClocksThrottleReasonSyncBoost = 0x0000000000000010 +nvmlClocksThrottleReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksThrottleReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksThrottleReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksThrottleReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksThrottleReasonNone = 0x0000000000000000 +nvmlClocksThrottleReasonAll = ( + nvmlClocksThrottleReasonNone | + nvmlClocksThrottleReasonGpuIdle | + nvmlClocksThrottleReasonApplicationsClocksSetting | + nvmlClocksThrottleReasonSwPowerCap | + nvmlClocksThrottleReasonHwSlowdown | + nvmlClocksThrottleReasonSyncBoost | + nvmlClocksThrottleReasonSwThermalSlowdown | + nvmlClocksThrottleReasonHwThermalSlowdown | + nvmlClocksThrottleReasonHwPowerBrakeSlowdown | + nvmlClocksThrottleReasonDisplayClockSetting + ) + +class c_nvmlEventData_t(_PrintableStructure): + _fields_ = [ + ('device', c_nvmlDevice_t), + ('eventType', c_ulonglong), + ('eventData', c_ulonglong), + ('gpuInstanceId', c_uint), + ('computeInstanceId', c_uint) + ] + _fmt_ = {'eventType': "0x%08X"} + +class c_nvmlAccountingStats_t(_PrintableStructure): + _fields_ = [ + ('gpuUtilization', c_uint), + ('memoryUtilization', c_uint), + ('maxMemoryUsage', c_ulonglong), + ('time', c_ulonglong), + ('startTime', c_ulonglong), + ('isRunning', c_uint), + ('reserved', c_uint * 5) + ] + +class c_nvmlVgpuVersion_t(Structure): + _fields_ = [("minVersion", c_uint), + ("maxVersion", c_uint) + ] + +class c_nvmlVgpuMetadata_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("revision", c_uint), + ("guestInfoState", _nvmlVgpuGuestInfoState_t), + ("guestDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("reserved", c_uint * 6), + ("vgpuVirtualizationCaps", c_uint), + ("guestVgpuVersion", c_uint), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_METADATA_OPAQUE_DATA_SIZE) + ] + +class c_nvmlVgpuPgpuMetadata_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("revision", c_uint), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("pgpuVirtualizationCaps", c_uint), + ("reserved", c_uint * 5), + ("hostSupportedVgpuRange", c_nvmlVgpuVersion_t), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + ] + +class c_nvmlVgpuPgpuCompatibility_t(Structure): + _fields_ = [("vgpuVmCompatibility", _nvmlVgpuVmCompatibility_t), + ("compatibilityLimitCode", 
_nvmlVgpuPgpuCompatibilityLimitCode_t) + ] + +## vGPU scheduler policy defines +NVML_VGPU_SCHEDULER_POLICY_UNKNOWN = 0 +NVML_VGPU_SCHEDULER_POLICY_BEST_EFFORT = 1 +NVML_VGPU_SCHEDULER_POLICY_EQUAL_SHARE = 2 +NVML_VGPU_SCHEDULER_POLICY_FIXED_SHARE = 3 + +## Supported vGPU scheduler policy count +NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT = 3 + +NVML_SCHEDULER_SW_MAX_LOG_ENTRIES = 200 + +NVML_VGPU_SCHEDULER_ARR_DEFAULT = 0 +NVML_VGPU_SCHEDULER_ARR_DISABLE = 1 +NVML_VGPU_SCHEDULER_ARR_ENABLE = 2 + +class c_nvmlVgpuSchedDataWithARR_t(_PrintableStructure): + _fields_ = [ + ('avgFactor', c_uint), + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedData_t(_PrintableStructure): + _fields_ = [ + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedulerParams_t(Union): + _fields_ = [ + ('vgpuSchedDataWithARR', c_nvmlVgpuSchedDataWithARR_t), + ('vgpuSchedData', c_nvmlVgpuSchedData_t), + ] + +class c_nvmlVgpuSchedulerLogEntry_t(_PrintableStructure): + _fields_ = [ + ('timestamp', c_ulonglong), + ('timeRunTotal', c_ulonglong), + ('timeRun', c_ulonglong), + ('swRunlistId', c_uint), + ('targetTimeSlice', c_ulonglong), + ('cumulativePreemptionTime', c_ulonglong), + ] + +class c_nvmlVgpuSchedulerLog_t(_PrintableStructure): + _fields_ = [ + ('engineId', c_uint), + ('schedulerPolicy', c_uint), + ('arrMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerParams_t), + ('entriesCount', c_uint), + ('logEntries', c_nvmlVgpuSchedulerLogEntry_t * NVML_SCHEDULER_SW_MAX_LOG_ENTRIES), + ] + +class c_nvmlVgpuSchedulerGetState_t(_PrintableStructure): + _fields_ = [ + ('schedulerPolicy', c_uint), + ('arrMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerParams_t), + ] + +class c_nvmlVgpuSchedSetDataWithARR_t(_PrintableStructure): + _fields_ = [ + ('avgFactor', c_uint), + ('frequency', c_uint), + ] + +class c_nvmlVgpuSchedSetData_t(_PrintableStructure): + _fields_ = [ + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedulerSetParams_t(Union): + _fields_ = [ + ('vgpuSchedDataWithARR', c_nvmlVgpuSchedSetDataWithARR_t), + ('vgpuSchedData', c_nvmlVgpuSchedSetData_t), + ] + +class c_nvmlVgpuSchedulerSetState_t(_PrintableStructure): + _fields_ = [ + ('schedulerPolicy', c_uint), + ('enableARRMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerSetParams_t), + ] + +class c_nvmlVgpuSchedulerCapabilities_t(_PrintableStructure): + _fields_ = [ + ('supportedSchedulers', c_uint * NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT), + ('maxTimeslice', c_uint), + ('minTimeslice', c_uint), + ('isArrModeSupported', c_uint), + ('maxFrequencyForARR', c_uint), + ('minFrequencyForARR', c_uint), + ('maxAvgFactorForARR', c_uint), + ('minAvgFactorForARR', c_uint), + ] + +class c_nvmlFBCStats_t(Structure): + _fields_ = [("sessionsCount", c_uint), + ("averageFPS", c_uint), + ("averageLatency", c_uint) + ] + +class c_nvmlFBCSession_t(_PrintableStructure): + _fields_ = [ + ('sessionId', c_uint), + ('pid', c_uint), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('displayOrdinal', c_uint), + ('sessionType', c_uint), + ('sessionFlags', c_uint), + ('hMaxResolution', c_uint), + ('vMaxResolution', c_uint), + ('hResolution', c_uint), + ('vResolution', c_uint), + ('averageFPS', c_uint), + ('averageLatency', c_uint), + ] + +NVML_DEVICE_MIG_DISABLE = 0x0 +NVML_DEVICE_MIG_ENABLE = 0x1 + +NVML_GPU_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_GPU_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_GPU_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_GPU_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_GPU_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_GPU_INSTANCE_PROFILE_8_SLICE = 0x5 
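The MIG enable/disable flags and GPU-instance-profile IDs defined here (the profile list continues immediately after this example) are consumed by the MIG helpers later in this module. Below is a minimal sketch, assuming the nvidia-ml-py helpers `nvmlDeviceGetMigMode` and `nvmlDeviceGetGpuInstanceProfileInfo` and a MIG-capable GPU at index 0; GPUs without MIG typically raise an `NVMLError` (Not Supported) from these calls.

```python
from pynvml import (  # illustrative import path
    NVML_DEVICE_MIG_ENABLE, NVML_GPU_INSTANCE_PROFILE_1_SLICE, NVMLError,
    nvmlDeviceGetGpuInstanceProfileInfo, nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMigMode, nvmlInit, nvmlShutdown,
)

nvmlInit()
try:
    handle = nvmlDeviceGetHandleByIndex(0)
    try:
        current, pending = nvmlDeviceGetMigMode(handle)
    except NVMLError:
        current = pending = None  # GPU has no MIG support at all
    if current == NVML_DEVICE_MIG_ENABLE:
        # Look up the capacity of the 1-slice profile: how many instances fit
        # and how much framebuffer each one gets.
        info = nvmlDeviceGetGpuInstanceProfileInfo(handle, NVML_GPU_INSTANCE_PROFILE_1_SLICE)
        print("1-slice GI profile: %d instances, %d MB each"
              % (info.instanceCount, info.memorySizeMB))
    else:
        print("MIG disabled or unsupported (current=%s, pending=%s)" % (current, pending))
finally:
    nvmlShutdown()
```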
+NVML_GPU_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_GPU_INSTANCE_PROFILE_2_SLICE_REV1 = 0x8 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV2 = 0x9 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_GFX = 0xA +NVML_GPU_INSTANCE_PROFILE_2_SLICE_GFX = 0xB +NVML_GPU_INSTANCE_PROFILE_4_SLICE_GFX = 0xC +NVML_GPU_INSTANCE_PROFILE_COUNT = 0xD + +class c_nvmlGpuInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), + ("size", c_uint) + ] + +class c_nvmlGpuInstanceProfileInfo_t(Structure): + _fields_ = [("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + +nvmlGpuInstanceProfileInfo_v2 = 0x02000098 + +class c_nvmlGpuInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE) + ] + + def __init__(self): + super(c_nvmlGpuInstanceProfileInfo_v2_t, self).__init__(version=nvmlGpuInstanceProfileInfo_v2) + +class c_nvmlGpuInstanceInfo_t(Structure): + _fields_ = [("device", c_nvmlDevice_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", c_nvmlGpuInstancePlacement_t) + ] + +class struct_c_nvmlGpuInstance_t(Structure): + pass # opaque handle +c_nvmlGpuInstance_t = POINTER(struct_c_nvmlGpuInstance_t) + +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_COMPUTE_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_COMPUTE_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_COMPUTE_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_COMPUTE_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_COMPUTE_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_COMPUTE_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_COMPUTE_INSTANCE_PROFILE_COUNT = 0x8 + +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = 0x0 +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = 0x1 + +class c_nvmlComputeInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), + ("size", c_uint) + ] + +class c_nvmlComputeInstanceProfileInfo_t(Structure): + _fields_ = [("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint) + ] + +nvmlComputeInstanceProfileInfo_v2 = 0x02000088 + +class c_nvmlComputeInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE) + ] + + def __init__(self): + super(c_nvmlComputeInstanceProfileInfo_v2_t, self).__init__(version=nvmlComputeInstanceProfileInfo_v2) + +class c_nvmlComputeInstanceInfo_t(Structure): + _fields_ = [("device", c_nvmlDevice_t), + ("gpuInstance", c_nvmlGpuInstance_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", 
c_nvmlComputeInstancePlacement_t) + ] + +NVML_MAX_GPU_UTILIZATIONS = 8 +NVML_GPU_UTILIZATION_DOMAIN_GPU = 0 +NVML_GPU_UTILIZATION_DOMAIN_FB = 1 +NVML_GPU_UTILIZATION_DOMAIN_VID = 2 +NVML_GPU_UTILIZATION_DOMAIN_BUS = 3 +class c_nvmlGpuDynamicPstatesUtilization_t(Structure): + _fields_ = [("bIsPresent", c_uint, 1), + ("percentage", c_uint), + ("incThreshold", c_uint), + ("decThreshold", c_uint)] +class c_nvmlGpuDynamicPstatesInfo_t(Structure): + _fields_ = [("flags", c_uint), + ("utilization", c_nvmlGpuDynamicPstatesUtilization_t * NVML_MAX_GPU_UTILIZATIONS)] + +NVML_MAX_THERMAL_SENSORS_PER_GPU = 3 + +NVML_THERMAL_TARGET_NONE = 0 +NVML_THERMAL_TARGET_GPU = 1 +NVML_THERMAL_TARGET_MEMORY = 2 +NVML_THERMAL_TARGET_POWER_SUPPLY = 4 +NVML_THERMAL_TARGET_BOARD = 8 +NVML_THERMAL_TARGET_VCD_BOARD = 9 +NVML_THERMAL_TARGET_VCD_INLET = 10 +NVML_THERMAL_TARGET_VCD_OUTLET = 11 +NVML_THERMAL_TARGET_ALL = 15 +NVML_THERMAL_TARGET_UNKNOWN = -1 + +NVML_THERMAL_CONTROLLER_NONE = 0 +NVML_THERMAL_CONTROLLER_GPU_INTERNAL = 1 +NVML_THERMAL_CONTROLLER_ADM1032 = 2 +NVML_THERMAL_CONTROLLER_ADT7461 = 3 +NVML_THERMAL_CONTROLLER_MAX6649 = 4 +NVML_THERMAL_CONTROLLER_MAX1617 = 5 +NVML_THERMAL_CONTROLLER_LM99 = 6 +NVML_THERMAL_CONTROLLER_LM89 = 7 +NVML_THERMAL_CONTROLLER_LM64 = 8 +NVML_THERMAL_CONTROLLER_G781 = 9 +NVML_THERMAL_CONTROLLER_ADT7473 = 10 +NVML_THERMAL_CONTROLLER_SBMAX6649 = 11 +NVML_THERMAL_CONTROLLER_VBIOSEVT = 12 +NVML_THERMAL_CONTROLLER_OS = 13 +NVML_THERMAL_CONTROLLER_NVSYSCON_CANOAS = 14 +NVML_THERMAL_CONTROLLER_NVSYSCON_E551 = 15 +NVML_THERMAL_CONTROLLER_MAX6649R = 16 +NVML_THERMAL_CONTROLLER_ADT7473S = 17 +NVML_THERMAL_CONTROLLER_UNKNOWN = -1 + +class c_nvmlGpuThermalSensor_t(Structure): + _fields_ = [("controller", c_int), + ("defaultMinTemp", c_int), + ("defaultMaxTemp", c_int), + ("currentTemp", c_int), + ("target", c_int)] +class c_nvmlGpuThermalSettings_t(Structure): + _fields_ = [("count", c_uint), + ("sensor", c_nvmlGpuThermalSensor_t * NVML_MAX_THERMAL_SENSORS_PER_GPU)] + +_nvmlCoolerControl_t = c_uint +NVML_THERMAL_COOLER_SIGNAL_NONE = 0 +NVML_THERMAL_COOLER_SIGNAL_TOGGLE = 1 +NVML_THERMAL_COOLER_SIGNAL_VARIABLE = 2 +NVML_THERMAL_COOLER_SIGNAL_COUNT = 3 + +_nvmlCoolerTarget_t = c_uint +NVML_THERMAL_COOLER_TARGET_NONE = (1 << 0) +NVML_THERMAL_COOLER_TARGET_GPU = (1 << 1) +NVML_THERMAL_COOLER_TARGET_MEMORY = (1 << 2) +NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY = (1 << 3) +NVML_THERMAL_COOLER_TARGET_GPU_RELATED = (NVML_THERMAL_COOLER_TARGET_GPU | NVML_THERMAL_COOLER_TARGET_MEMORY | NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY) + +class c_nvmlCoolerInfo_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("index", c_uint), + ("coolerControlType", _nvmlCoolerControl_t), + ("coolerTarget", _nvmlCoolerTarget_t) + ] + +nvmlCoolerInfo_v1 = 0x1000010 + +def nvmlDeviceGetCoolerInfo(handle): + c_coolerInfo = c_nvmlCoolerInfo_t() + c_coolerInfo.version = nvmlCoolerInfo_v1 + c_coolerInfo.index = 0 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCoolerInfo") + ret = fn(handle, byref(c_coolerInfo)) + _nvmlCheckReturn(ret) + return [c_coolerInfo.coolerControlType, c_coolerInfo.coolerTarget] + +class struct_c_nvmlComputeInstance_t(Structure): + pass # opaque handle +c_nvmlComputeInstance_t = POINTER(struct_c_nvmlComputeInstance_t) + +class c_nvmlDeviceAttributes(Structure): + _fields_ = [("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("gpuInstanceSliceCount", 
c_uint), + ("computeInstanceSliceCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + +class c_nvmlRowRemapperHistogramValues(Structure): + _fields_ = [("max", c_uint), + ("high", c_uint), + ("partial", c_uint), + ("low", c_uint), + ("none", c_uint) + ] + +NVML_GPU_CERT_CHAIN_SIZE = 0x1000 +NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE = 0x1400 +NVML_CC_GPU_CEC_NONCE_SIZE = 0x20 +NVML_CC_GPU_ATTESTATION_REPORT_SIZE = 0x2000 +NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE = 0x1000 +NVML_CC_CEC_ATTESTATION_REPORT_NOT_PRESENT = 0 +NVML_CC_CEC_ATTESTATION_REPORT_PRESENT = 1 + +class c_nvmlConfComputeSystemState_t(Structure): + _fields_ = [('environment', c_uint), + ('ccFeature', c_uint), + ('devToolsMode', c_uint), + ] + +nvmlSystemConfComputeSettings_v1 = 0x1000014 + +class c_nvmlSystemConfComputeSettings_v1_t(Structure): + _fields_ = [('version', c_uint), + ('environment', c_uint), + ('ccFeature', c_uint), + ('devToolsMode', c_uint), + ('multiGpuMode', c_uint), + ] + def __init__(self): + super(c_nvmlSystemConfComputeSettings_v1_t, self).__init__(version=nvmlSystemConfComputeSettings_v1) + +class c_nvmlConfComputeSystemCaps_t(Structure): + _fields_ = [('cpuCaps', c_uint), + ('gpusCaps', c_uint), + ] + +class c_nvmlConfComputeMemSizeInfo_t(Structure): + _fields_ = [('protectedMemSizeKib', c_ulonglong), + ('unprotectedMemSizeKib', c_ulonglong), + ] + +class c_nvmlConfComputeGpuCertificate_t(Structure): + _fields_ = [('certChainSize', c_uint), + ('attestationCertChainSize', c_uint), + ('certChain', c_uint8 * NVML_GPU_CERT_CHAIN_SIZE), + ('attestationCertChain', c_uint8 * NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE), + ] + +class c_nvmlConfComputeGpuAttestationReport_t(Structure): + _fields_ = [('isCecAttestationReportPresent', c_uint), + ('attestationReportSize', c_uint), + ('cecAttestationReportSize', c_uint), + ('nonce', c_uint8 * NVML_CC_GPU_CEC_NONCE_SIZE), + ('attestationReport', c_uint8 * NVML_CC_GPU_ATTESTATION_REPORT_SIZE), + ('cecAttestationReport', c_uint8 * NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE), + ] + +class c_nvmlConfComputeSetKeyRotationThresholdInfo_t(Structure): + _fields_ = [('version', c_uint), + ('maxAttackerAdvantage', c_ulong), + ] +ConfComputeSetKeyRotationThresholdInfo_v1 = 0x1000010 + +class c_nvmlConfComputeGetKeyRotationThresholdInfo_t(Structure): + _fields_ = [('version', c_uint), + ('attackerAdvantage', c_ulong), + ] +ConfComputeGetKeyRotationThresholdInfo_v1 = 0x1000010 + + +## string/bytes conversion for ease of use +def convertStrBytes(func): + ''' + In python 3, strings are unicode instead of bytes, and need to be converted for ctypes + Args from caller: (1, 'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>) + Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF)> + ---- + Returned from function: b'returned string' + Returned to caller: 'returned string' + ''' + @wraps(func) + def wrapper(*args, **kwargs): + # encoding a str returns bytes in python 2 and 3 + args = [arg.encode() if isinstance(arg, str) else arg for arg in args] + res = func(*args, **kwargs) + # In python 2, str and bytes are the same + # In python 3, str is unicode and should be decoded. + # Ctypes handles most conversions, this only effects c_char and char arrays. 
+ if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + if sys.version_info >= (3,): + return wrapper + return func + +def throwOnVersionMismatch(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except NVMLError_FunctionNotFound: + raise NVMLLibraryMismatchError("Unversioned function called and the " + "pyNVML version does not match the NVML lib version. " + "Either use matching pyNVML and NVML lib versions or " + "use a versioned function such as " + func.__name__ + "_v2") + return wrapper + +## C function wrappers ## +def nvmlInitWithFlags(flags): + _LoadNvmlLibrary() + + # + # Initialize the library + # + fn = _nvmlGetFunctionPointer("nvmlInitWithFlags") + ret = fn(flags) + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + _nvmlLib_refcount += 1 + libLoadLock.release() + return None + +def nvmlInit(): + nvmlInitWithFlags(0) + return None + +def _LoadNvmlLibrary(): + ''' + Load the library if it isn't loaded already + ''' + global nvmlLib + + if (nvmlLib == None): + # lock to ensure only one caller loads the library + libLoadLock.acquire() + + try: + # ensure the library still isn't loaded + if (nvmlLib == None): + try: + if (sys.platform[:3] == "win"): + # cdecl calling convention + try: + # Check for nvml.dll in System32 first for DCH drivers + nvmlLib = CDLL(os.path.join(os.getenv("WINDIR", "C:/Windows"), "System32/nvml.dll")) + except OSError as ose: + # If nvml.dll is not found in System32, it should be in ProgramFiles + # load nvml.dll from %ProgramFiles%/NVIDIA Corporation/NVSMI/nvml.dll + nvmlLib = CDLL(os.path.join(os.getenv("ProgramFiles", "C:/Program Files"), "NVIDIA Corporation/NVSMI/nvml.dll")) + else: + # assume linux + nvmlLib = CDLL("libnvidia-ml.so.1") + except OSError as ose: + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + if (nvmlLib == None): + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + +def nvmlShutdown(): + # + # Leave the library loaded, but shutdown the interface + # + fn = _nvmlGetFunctionPointer("nvmlShutdown") + ret = fn() + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + if (0 < _nvmlLib_refcount): + _nvmlLib_refcount -= 1 + libLoadLock.release() + return None + +# Added in 2.285 +@convertStrBytes +def nvmlErrorString(result): + fn = _nvmlGetFunctionPointer("nvmlErrorString") + fn.restype = c_char_p # otherwise return is an int + ret = fn(result) + return ret + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetNVMLVersion(): + c_version = create_string_buffer(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetNVMLVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +def nvmlSystemGetCudaDriverVersion(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + +def nvmlSystemGetCudaDriverVersion_v2(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion_v2") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetProcessName(pid): + c_name = create_string_buffer(1024) + fn = _nvmlGetFunctionPointer("nvmlSystemGetProcessName") + 
ret = fn(c_uint(pid), c_name, c_uint(1024)) + _nvmlCheckReturn(ret) + return c_name.value + +@convertStrBytes +def nvmlSystemGetDriverVersion(): + c_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 2.285 +def nvmlSystemGetHicVersion(): + c_count = c_uint(0) + hics = None + fn = _nvmlGetFunctionPointer("nvmlSystemGetHicVersion") + + # get the count + ret = fn(byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # If there are no hics + if (c_count.value == 0): + return [] + + hic_array = c_nvmlHwbcEntry_t * c_count.value + hics = hic_array() + ret = fn(byref(c_count), hics) + _nvmlCheckReturn(ret) + return hics + +def nvmlSystemGetDriverBranch(): + c_branchInfo = c_nvmlSystemDriverBranchInfo_v1_t(0) + c_branchInfo.version = SystemDriverBranchInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverBranch") + ret = fn(byref(c_branchInfo), c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_branchInfo + +## Unit get functions +def nvmlUnitGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlUnitGetHandleByIndex(index): + c_index = c_uint(index) + unit = c_nvmlUnit_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetHandleByIndex") + ret = fn(c_index, byref(unit)) + _nvmlCheckReturn(ret) + return unit + +def nvmlUnitGetUnitInfo(unit): + c_info = c_nvmlUnitInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetUnitInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlUnitGetLedState(unit): + c_state = c_nvmlLedState_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetLedState") + ret = fn(unit, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + +def nvmlUnitGetPsuInfo(unit): + c_info = c_nvmlPSUInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetPsuInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlUnitGetTemperature(unit, type): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetTemperature") + ret = fn(unit, c_uint(type), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlUnitGetFanSpeedInfo(unit): + c_speeds = c_nvmlUnitFanSpeeds_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetFanSpeedInfo") + ret = fn(unit, byref(c_speeds)) + _nvmlCheckReturn(ret) + return c_speeds + +# added to API +def nvmlUnitGetDeviceCount(unit): + c_count = c_uint(0) + # query the unit to determine device count + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), None) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = NVML_SUCCESS + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlUnitGetDevices(unit): + c_count = c_uint(nvmlUnitGetDeviceCount(unit)) + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return c_devices + +## Device get functions +def nvmlDeviceGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def 
nvmlDeviceGetHandleByIndex(index): + c_index = c_uint(index) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByIndex_v2") + ret = fn(c_index, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleBySerial(serial): + c_serial = c_char_p(serial) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleBySerial") + ret = fn(c_serial, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleByUUID(uuid): + c_uuid = c_char_p(uuid) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByUUID") + ret = fn(c_uuid, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleByPciBusId(pciBusId): + c_busId = c_char_p(pciBusId) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByPciBusId_v2") + ret = fn(c_busId, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetName(handle): + c_name = create_string_buffer(NVML_DEVICE_NAME_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetName") + ret = fn(handle, c_name, c_uint(NVML_DEVICE_NAME_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_name.value + +class c_nvmlDevicePerfModes_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('str', c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + +nvmlDevicePerfModes_v1 = 0x1000804 + +@convertStrBytes +def nvmlDeviceGetPerformanceModes(handle): + perfModes = c_nvmlDevicePerfModes_v1_t() + perfModes.version = nvmlDevicePerfModes_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceModes") + ret = fn(handle, byref(perfModes)) + _nvmlCheckReturn(ret) + return perfModes.str + +class c_nvmlDeviceCurrentClockFreqs_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('str', c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + +nvmlDeviceCurrentClockFreqs_v1 = 0x1000804 + +@convertStrBytes +def nvmlDeviceGetCurrentClockFreqs(handle): + currentClockFreqs = c_nvmlDeviceCurrentClockFreqs_v1_t() + currentClockFreqs.version = nvmlDeviceCurrentClockFreqs_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClockFreqs") + ret = fn(handle, byref(currentClockFreqs)) + _nvmlCheckReturn(ret) + return currentClockFreqs.str + +def nvmlDeviceGetBoardId(handle): + c_id = c_uint(); + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardId") + ret = fn(handle, byref(c_id)) + _nvmlCheckReturn(ret) + return c_id.value + +def nvmlDeviceGetMultiGpuBoard(handle): + c_multiGpu = c_uint(); + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMultiGpuBoard") + ret = fn(handle, byref(c_multiGpu)) + _nvmlCheckReturn(ret) + return c_multiGpu.value + +def nvmlDeviceGetBrand(handle): + c_type = _nvmlBrandType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBrand") + ret = fn(handle, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + +def nvmlDeviceGetC2cModeInfoV1(handle): + c_info = c_nvmlC2cModeInfo_v1_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetC2cModeInfoV") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlDeviceGetC2cModeInfoV(handle): + return nvmlDeviceGetC2cModeInfoV1(handle) + +@convertStrBytes +def nvmlDeviceGetBoardPartNumber(handle): + c_part_number = create_string_buffer(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardPartNumber") + ret = fn(handle, c_part_number, c_uint(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return 
c_part_number.value + +@convertStrBytes +def nvmlDeviceGetSerial(handle): + c_serial = create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSerial") + ret = fn(handle, c_serial, c_uint(NVML_DEVICE_SERIAL_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_serial.value + +def nvmlDeviceGetModuleId(handle, moduleId=c_uint()): + isReference = type(moduleId) is not c_uint + moduleIdRef = moduleId if isReference else byref(moduleId) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetModuleId") + ret = fn(handle, moduleIdRef) + if isReference: + return ret + else: + _nvmlCheckReturn(ret) + return moduleId.value + +def nvmlDeviceGetMemoryAffinity(handle, nodeSetSize, scope): + affinity_array = c_ulonglong * nodeSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryAffinity") + ret = fn(handle, nodeSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceGetCpuAffinityWithinScope(handle, cpuSetSize, scope): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinityWithinScope") + ret = fn(handle, cpuSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceGetCpuAffinity(handle, cpuSetSize): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinity") + ret = fn(handle, cpuSetSize, byref(c_affinity)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceSetCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNumaNodeId(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumaNodeId") + node = c_int() + ret = fn(handle, byref(node)) + _nvmlCheckReturn(ret) + return node.value + +def nvmlDeviceGetMinorNumber(handle): + c_minor_number = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinorNumber") + ret = fn(handle, byref(c_minor_number)) + _nvmlCheckReturn(ret) + return c_minor_number.value + +@convertStrBytes +def nvmlDeviceGetUUID(handle): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUUID") + ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlDeviceGetInforomVersion(handle, infoRomObject): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomVersion") + ret = fn(handle, _nvmlInforomObject_t(infoRomObject), + c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 4.304 +@convertStrBytes +def nvmlDeviceGetInforomImageVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomImageVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 4.304 +def nvmlDeviceGetInforomConfigurationChecksum(handle): + c_checksum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomConfigurationChecksum") + ret = fn(handle, 
byref(c_checksum)) + _nvmlCheckReturn(ret) + return c_checksum.value + +# Added in 4.304 +def nvmlDeviceValidateInforom(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetLastBBXFlushTime(handle): + c_timestamp = c_ulonglong() + c_durationUs = c_ulong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetLastBBXFlushTime") + ret = fn(handle, byref(c_timestamp), byref(c_durationUs)) + _nvmlCheckReturn(ret) + return [c_timestamp.value, c_durationUs.value] + +def nvmlDeviceGetDisplayMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceGetDisplayActive(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayActive") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetPersistenceMode(handle): + c_state = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPersistenceMode") + ret = fn(handle, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +def nvmlDeviceGetPciInfoExt(handle, c_info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfoExt") + ret = fn(handle, c_info) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetPciInfo_v3(handle): + c_info = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfo_v3") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlDeviceGetPciInfo(handle): + return nvmlDeviceGetPciInfo_v3(handle) + +def nvmlDeviceGetClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 2.285 +def nvmlDeviceGetMaxClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 4.304 +def nvmlDeviceGetApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +def nvmlDeviceGetMaxCustomerBoostClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxCustomerBoostClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +def nvmlDeviceGetClock(handle, type, id): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClock") + ret = fn(handle, _nvmlClockType_t(type), _nvmlClockId_t(id), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 5.319 +def nvmlDeviceGetDefaultApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 4.304 +def nvmlDeviceGetSupportedMemoryClocks(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedMemoryClocks") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no clocks + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + clocks_array = 
c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + +# Added in 4.304 +def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedGraphicsClocks") + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no clocks + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetFanSpeed(handle): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed") + ret = fn(handle, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetFanSpeed_v2(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed_v2") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +class c_nvmlFanSpeedInfo_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('fan', c_uint), + ('speed', c_uint), + ] + +nvmlFanSpeedInfo_v1 = 0x100000C + +def nvmlDeviceGetFanSpeedRPM(handle): + c_fanSpeed = c_nvmlFanSpeedInfo_t() + c_fanSpeed.fan = 0 + c_fanSpeed.version = nvmlFanSpeedInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeedRPM") + ret = fn(handle, byref(c_fanSpeed)) + _nvmlCheckReturn(ret) + return c_fanSpeed.speed + +def nvmlDeviceGetTargetFanSpeed(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTargetFanSpeed") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetNumFans(device): + c_numFans = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumFans") + ret = fn(device, byref(c_numFans)) + _nvmlCheckReturn(ret) + return c_numFans.value + +def nvmlDeviceSetDefaultFanSpeed_v2(handle, index): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultFanSpeed_v2"); + ret = fn(handle, index) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetMinMaxFanSpeed(handle, minSpeed=c_uint(), maxSpeed=c_uint()): + isReference = (type(minSpeed) is not c_uint) or (type(maxSpeed) is not c_uint) + minSpeedRef = minSpeed if isReference else byref(minSpeed) + maxSpeedRef = maxSpeed if isReference else byref(maxSpeed) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxFanSpeed") + ret = fn(handle, minSpeedRef, maxSpeedRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [minSpeed.value, maxSpeed.value] + +def nvmlDeviceGetFanControlPolicy_v2(handle, fan, fanControlPolicy=c_uint()): + isReference = type(fanControlPolicy) is not c_uint + fanControlPolicyRef = fanControlPolicy if isReference else byref(fanControlPolicy) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanControlPolicy_v2") + ret = fn(handle, fan, fanControlPolicyRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else fanControlPolicy.value + +def nvmlDeviceSetFanControlPolicy(handle, fan, fanControlPolicy): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanControlPolicy") + ret = 
fn(handle, fan, _nvmlFanControlPolicy_t(fanControlPolicy)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +class c_nvmlTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('sensorType', _nvmlTemperatureSensors_t), + ('temperature', c_int), + ] +nvmlTemperature_v1 = 0x100000C + +def nvmlDeviceGetTemperatureV1(handle, sensor): + c_temp = c_nvmlTemperature_v1_t() + c_temp.version = nvmlTemperature_v1 + c_temp.sensorType = _nvmlTemperatureSensors_t(sensor) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureV") + ret = fn(handle, byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.temperature + +def nvmlDeviceGetTemperatureV(handle, sensor, version=nvmlTemperature_v1): + if version == nvmlTemperature_v1: + return nvmlDeviceGetTemperatureV1(handle, sensor) + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + +# DEPRECATED use nvmlDeviceGetTemperatureV instead +def nvmlDeviceGetTemperature(handle, sensor): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperature") + ret = fn(handle, _nvmlTemperatureSensors_t(sensor), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlDeviceGetTemperatureThreshold(handle, threshold): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlDeviceSetTemperatureThreshold(handle, threshold, temp): + c_temp = c_uint() + c_temp.value = temp + fn = _nvmlGetFunctionPointer("nvmlDeviceSetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetMarginTemperature(handle): + c_marginTempInfo = c_nvmlMarginTemperature_v1_t() + c_marginTempInfo.version = nvmlMarginTemperature_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMarginTemperature") + ret = fn(handle, byref(c_marginTempInfo)) + _nvmlCheckReturn(ret) + return c_marginTempInfo.marginTemperature + +# DEPRECATED use nvmlDeviceGetPerformanceState +def nvmlDeviceGetPowerState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + +def nvmlDeviceGetPerformanceState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + +def nvmlDeviceGetPowerManagementMode(handle): + c_pcapMode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementMode") + ret = fn(handle, byref(c_pcapMode)) + _nvmlCheckReturn(ret) + return c_pcapMode.value + +def nvmlDeviceGetPowerManagementLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + +# Added in 4.304 +def nvmlDeviceGetPowerManagementLimitConstraints(handle): + c_minLimit = c_uint() + c_maxLimit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimitConstraints") + ret = fn(handle, byref(c_minLimit), byref(c_maxLimit)) + _nvmlCheckReturn(ret) + return [c_minLimit.value, c_maxLimit.value] + +# Added in 4.304 +def nvmlDeviceGetPowerManagementDefaultLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementDefaultLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) 
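+    # the default power management limit is reported by NVML in milliwatts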
+ return c_limit.value + + +# Added in 331 +def nvmlDeviceGetEnforcedPowerLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEnforcedPowerLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + +def nvmlDeviceGetPowerUsage(handle): + c_watts = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerUsage") + ret = fn(handle, byref(c_watts)) + _nvmlCheckReturn(ret) + return c_watts.value + +def nvmlDeviceGetTotalEnergyConsumption(handle): + c_millijoules = c_uint64() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEnergyConsumption") + ret = fn(handle, byref(c_millijoules)) + _nvmlCheckReturn(ret) + return c_millijoules.value + +# Added in 4.304 +def nvmlDeviceGetGpuOperationMode(handle): + c_currState = _nvmlGpuOperationMode_t() + c_pendingState = _nvmlGpuOperationMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuOperationMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + +# Added in 4.304 +def nvmlDeviceGetCurrentGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[0] + +# Added in 4.304 +def nvmlDeviceGetPendingGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[1] + +def nvmlDeviceGetMemoryInfo(handle, version=None): + if not version: + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo") + else: + c_memory = c_nvmlMemory_v2_t() + c_memory.version = version + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo_v2") + ret = fn(handle, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + +def nvmlDeviceGetBAR1MemoryInfo(handle): + c_bar1_memory = c_nvmlBAR1Memory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBAR1MemoryInfo") + ret = fn(handle, byref(c_bar1_memory)) + _nvmlCheckReturn(ret) + return c_bar1_memory + +def nvmlDeviceGetComputeMode(handle): + c_mode = _nvmlComputeMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceGetCudaComputeCapability(handle): + c_major = c_int() + c_minor = c_int() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCudaComputeCapability") + ret = fn(handle, byref(c_major), byref(c_minor)) + _nvmlCheckReturn(ret) + return (c_major.value, c_minor.value) + +def nvmlDeviceGetEccMode(handle): + c_currState = _nvmlEnableState_t() + c_pendingState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEccMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + +# added to API +def nvmlDeviceGetCurrentEccMode(handle): + return nvmlDeviceGetEccMode(handle)[0] + +# added to API +def nvmlDeviceGetPendingEccMode(handle): + return nvmlDeviceGetEccMode(handle)[1] + +def nvmlDeviceGetDefaultEccMode(handle): + c_defaultState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultEccMode") + ret = fn(handle, byref(c_defaultState)) + _nvmlCheckReturn(ret) + return [c_defaultState.value] + +def nvmlDeviceGetTotalEccErrors(handle, errorType, counterType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEccErrors") + ret = fn(handle, _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +# This is deprecated, instead use nvmlDeviceGetMemoryErrorCounter +def 
nvmlDeviceGetDetailedEccErrors(handle, errorType, counterType): + c_counts = c_nvmlEccErrorCounts_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDetailedEccErrors") + ret = fn(handle, _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), byref(c_counts)) + _nvmlCheckReturn(ret) + return c_counts + +# Added in 4.304 +def nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryErrorCounter") + ret = fn(handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + _nvmlMemoryLocation_t(locationType), + byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetUtilizationRates(handle): + c_util = c_nvmlUtilization_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUtilizationRates") + ret = fn(handle, byref(c_util)) + _nvmlCheckReturn(ret) + return c_util + +def nvmlDeviceGetEncoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetDecoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDecoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetJpgUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetJpgUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetOfaUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetOfaUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetPcieReplayCounter(handle): + c_replay = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieReplayCounter") + ret = fn(handle, byref(c_replay)) + _nvmlCheckReturn(ret) + return c_replay.value + +def nvmlDeviceGetDriverModel(handle): + c_currModel = _nvmlDriverModel_t() + c_pendingModel = _nvmlDriverModel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDriverModel") + ret = fn(handle, byref(c_currModel), byref(c_pendingModel)) + _nvmlCheckReturn(ret) + return [c_currModel.value, c_pendingModel.value] + +# added to API +def nvmlDeviceGetCurrentDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[0] + +# added to API +def nvmlDeviceGetPendingDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[1] + +# Added in 2.285 +@convertStrBytes +def nvmlDeviceGetVbiosVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVbiosVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # 
typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +@throwOnVersionMismatch +def nvmlDeviceGetComputeRunningProcesses(handle): + return nvmlDeviceGetComputeRunningProcesses_v3(handle) + +def nvmlDeviceGetGraphicsRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for 
this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +@throwOnVersionMismatch +def nvmlDeviceGetGraphicsRunningProcesses(handle): + return nvmlDeviceGetGraphicsRunningProcesses_v3(handle) + +@throwOnVersionMismatch +def nvmlDeviceGetMPSComputeRunningProcesses(handle): + return nvmlDeviceGetMPSComputeRunningProcesses_v3(handle) + +def nvmlDeviceGetMPSComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetRunningProcessDetailList(handle, version, mode): + c_processDetailList = c_nvmlProcessDetailList_t() + c_processDetailList.version = version + c_processDetailList.mode = mode + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRunningProcessDetailList") + + # first call to get the size + ret = fn(handle, byref(c_processDetailList)) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + c_procs = c_nvmlProcessDetail_v1_t * c_processDetailList.numProcArrayEntries + c_processDetailList.procArray = cast((c_procs)(), POINTER(c_nvmlProcessDetail_v1_t)) + + # make the call again + ret = fn(handle, byref(c_processDetailList)) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_processDetailList.numProcArrayEntries): + # use an alternative struct for this object + obj = c_processDetailList.procArray[i] + if (obj.usedGpuMemory == 
NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + obj.usedGpuMemory = None + if (obj.usedGpuCcProtectedMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + obj.usedGpuCcProtectedMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetAutoBoostedClocksEnabled(handle): + c_isEnabled = _nvmlEnableState_t() + c_defaultIsEnabled = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAutoBoostedClocksEnabled") + ret = fn(handle, byref(c_isEnabled), byref(c_defaultIsEnabled)) + _nvmlCheckReturn(ret) + return [c_isEnabled.value, c_defaultIsEnabled.value] + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +## Set functions +def nvmlUnitSetLedState(unit, color): + fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState") + ret = fn(unit, _nvmlLedColor_t(color)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetPersistenceMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPersistenceMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetComputeMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetComputeMode") + ret = fn(handle, _nvmlComputeMode_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetEccMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetEccMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearEccErrorCounts(handle, counterType): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearEccErrorCounts") + ret = fn(handle, _nvmlEccCounterType_t(counterType)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetDriverModel(handle, model): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDriverModel") + ret = fn(handle, _nvmlDriverModel_t(model)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled)) + _nvmlCheckReturn(ret) + return None + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +def nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags)) + _nvmlCheckReturn(ret) + return None + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +def nvmlDeviceSetGpuLockedClocks(handle, minGpuClockMHz, maxGpuClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuLockedClocks") + ret = fn(handle, c_uint(minGpuClockMHz), c_uint(maxGpuClockMHz)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetGpuLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetGpuLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetMemoryLockedClocks(handle, minMemClockMHz, maxMemClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemoryLockedClocks") + ret = fn(handle, c_uint(minMemClockMHz), c_uint(maxMemClockMHz)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetMemoryLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetMemoryLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetClkMonStatus(handle, c_clkMonInfo=nvmlClkMonStatus_t()): + isReference = type(c_clkMonInfo) is not nvmlClkMonStatus_t + c_clkMonInfoRef = c_clkMonInfo if 
isReference else byref(c_clkMonInfo) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClkMonStatus") + ret = fn(handle, c_clkMonInfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_clkMonInfo + +# Added in 4.304 +def nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetApplicationsClocks") + ret = fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz)) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceResetApplicationsClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetApplicationsClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceSetPowerManagementLimit(handle, limit): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit") + ret = fn(handle, c_uint(limit)) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceSetGpuOperationMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuOperationMode") + ret = fn(handle, _nvmlGpuOperationMode_t(mode)) + _nvmlCheckReturn(ret) + return None + +# Added in 2.285 +def nvmlEventSetCreate(): + fn = _nvmlGetFunctionPointer("nvmlEventSetCreate") + eventSet = c_nvmlEventSet_t() + ret = fn(byref(eventSet)) + _nvmlCheckReturn(ret) + return eventSet + +# Added in 2.285 +def nvmlDeviceRegisterEvents(handle, eventTypes, eventSet): + fn = _nvmlGetFunctionPointer("nvmlDeviceRegisterEvents") + ret = fn(handle, c_ulonglong(eventTypes), eventSet) + _nvmlCheckReturn(ret) + return None + +# Added in 2.285 +def nvmlDeviceGetSupportedEventTypes(handle): + c_eventTypes = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedEventTypes") + ret = fn(handle, byref(c_eventTypes)) + _nvmlCheckReturn(ret) + return c_eventTypes.value + +# raises NVML_ERROR_TIMEOUT exception on timeout +def nvmlEventSetWait_v2(eventSet, timeoutms): + fn = _nvmlGetFunctionPointer("nvmlEventSetWait_v2") + data = c_nvmlEventData_t() + ret = fn(eventSet, byref(data), c_uint(timeoutms)) + _nvmlCheckReturn(ret) + return data + +def nvmlEventSetWait(eventSet, timeoutms): + return nvmlEventSetWait_v2(eventSet, timeoutms) + +# Added in 2.285 +def nvmlEventSetFree(eventSet): + fn = _nvmlGetFunctionPointer("nvmlEventSetFree") + ret = fn(eventSet) + _nvmlCheckReturn(ret) + return None + +# Added in 3.295 +def nvmlDeviceOnSameBoard(handle1, handle2): + fn = _nvmlGetFunctionPointer("nvmlDeviceOnSameBoard") + onSameBoard = c_int() + ret = fn(handle1, handle2, byref(onSameBoard)) + _nvmlCheckReturn(ret) + return (onSameBoard.value != 0) + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + +def 
nvmlDeviceGetGpuMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 4.304 +def nvmlDeviceGetSupportedClocksThrottleReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +def nvmlDeviceGetSupportedClocksEventReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +# Added in 4.304 +def nvmlDeviceGetCurrentClocksThrottleReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +def nvmlDeviceGetCurrentClocksEventReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +# Added in 5.319 +def nvmlDeviceGetIndex(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIndex") + c_index = c_uint() + ret = fn(handle, byref(c_index)) + _nvmlCheckReturn(ret) + return c_index.value + +# Added in 5.319 +def nvmlDeviceGetAccountingMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceSetAccountingMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAccountingMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearAccountingPids(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearAccountingPids") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetAccountingStats(handle, pid): + stats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingStats") + ret = fn(handle, c_uint(pid), byref(stats)) + _nvmlCheckReturn(ret) + if (stats.maxMemoryUsage == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + stats.maxMemoryUsage = None + return stats + +def nvmlDeviceGetAccountingPids(handle): + count = c_uint(nvmlDeviceGetAccountingBufferSize(handle)) + pids = (c_uint * count.value)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingPids") + ret = fn(handle, byref(count), pids) + _nvmlCheckReturn(ret) + return list(map(int, pids[0:count.value])) + +def nvmlDeviceGetAccountingBufferSize(handle): + bufferSize = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingBufferSize") + ret = fn(handle, byref(bufferSize)) + _nvmlCheckReturn(ret) + return int(bufferSize.value) + +def nvmlDeviceGetRetiredPages(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = 
c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + ret = fn(device, c_source, byref(c_count), c_pages) + _nvmlCheckReturn(ret) + return list(map(int, c_pages[0:c_count.value])) + +def nvmlDeviceGetRetiredPages_v2(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages_v2") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + times_array = c_ulonglong * c_count.value + c_times = times_array() + ret = fn(device, c_source, byref(c_count), c_pages, c_times) + _nvmlCheckReturn(ret) + return [ { 'address': int(c_pages[i]), 'timestamp': int(c_times[i]) } for i in range(c_count.value) ]; + +def nvmlDeviceGetRetiredPagesPendingStatus(device): + c_pending = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPagesPendingStatus") + ret = fn(device, byref(c_pending)) + _nvmlCheckReturn(ret) + return int(c_pending.value) + +def nvmlDeviceGetAPIRestriction(device, apiType): + c_permission = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAPIRestriction") + ret = fn(device, _nvmlRestrictedAPI_t(apiType), byref(c_permission)) + _nvmlCheckReturn(ret) + return int(c_permission.value) + +def nvmlDeviceSetAPIRestriction(handle, apiType, isRestricted): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAPIRestriction") + ret = fn(handle, _nvmlRestrictedAPI_t(apiType), _nvmlEnableState_t(isRestricted)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetBridgeChipInfo(handle): + bridgeHierarchy = c_nvmlBridgeChipHierarchy_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBridgeChipInfo") + ret = fn(handle, byref(bridgeHierarchy)) + _nvmlCheckReturn(ret) + return bridgeHierarchy + +def nvmlDeviceGetSamples(device, sampling_type, timeStamp): + c_sampling_type = _nvmlSamplingType_t(sampling_type) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_count = c_uint(0) + c_sample_value_type = _nvmlValueType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples") + + ## First Call gets the size + ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), byref(c_sample_count), None) + + # Stop if this fails + if (ret != NVML_SUCCESS): + raise NVMLError(ret) + + sampleArray = c_sample_count.value * c_nvmlSample_t + c_samples = sampleArray() + ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), byref(c_sample_count), c_samples) + _nvmlCheckReturn(ret) + return (c_sample_value_type.value, c_samples[0:c_sample_count.value]) + +def nvmlDeviceGetViolationStatus(device, perfPolicyType): + c_perfPolicy_type = _nvmlPerfPolicyType_t(perfPolicyType) + c_violTime = c_nvmlViolationTime_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus") + + ## Invoke the method to get violation time + ret = fn(device, c_perfPolicy_type, byref(c_violTime)) + _nvmlCheckReturn(ret) + return c_violTime + +def nvmlDeviceGetPcieThroughput(device, counter): + c_util = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieThroughput") + ret = fn(device, _nvmlPcieUtilCounter_t(counter), byref(c_util)) + 
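+    # NVML reports the PCIe throughput counter in KB/s, sampled over a short interval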
_nvmlCheckReturn(ret) + return c_util.value + +def nvmlSystemGetTopologyGpuSet(cpuNumber): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlSystemGetTopologyGpuSet") + + # First call will get the size + ret = fn(cpuNumber, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(cpuNumber, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0:c_count.value]) + +def nvmlDeviceGetTopologyNearestGpus(device, level): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyNearestGpus") + + # First call will get the size + ret = fn(device, level, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(device, level, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0:c_count.value]) + +def nvmlDeviceGetTopologyCommonAncestor(device1, device2): + c_level = _nvmlGpuTopologyLevel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyCommonAncestor") + ret = fn(device1, device2, byref(c_level)) + _nvmlCheckReturn(ret) + return c_level.value + +def nvmlDeviceGetNvLinkUtilizationCounter(device, link, counter): + c_rxcounter = c_ulonglong() + c_txcounter = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationCounter") + ret = fn(device, link, counter, byref(c_rxcounter), byref(c_txcounter)) + _nvmlCheckReturn(ret) + return (c_rxcounter.value, c_txcounter.value) + +def nvmlDeviceFreezeNvLinkUtilizationCounter(device, link, counter, freeze): + fn = _nvmlGetFunctionPointer("nvmlDeviceFreezeNvLinkUtilizationCounter") + ret = fn(device, link, counter, freeze) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetNvLinkUtilizationCounter(device, link, counter): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkUtilizationCounter") + ret = fn(device, link, counter) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetNvLinkUtilizationControl(device, link, counter, control, reset): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(control), reset) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNvLinkUtilizationControl(device, link, counter): + c_control = nvmlNvLinkUtilizationControl_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(c_control)) + _nvmlCheckReturn(ret) + return c_control + +def nvmlDeviceGetNvLinkCapability(device, link, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkCapability") + ret = fn(device, link, capability, byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceGetNvLinkErrorCounter(device, link, counter): + c_result = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkErrorCounter") + ret = fn(device, link, counter, byref(c_result)) + _nvmlCheckReturn(ret) + return c_result.value + +def nvmlDeviceResetNvLinkErrorCounters(device, link): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkErrorCounters") + ret = fn(device, link) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNvLinkRemotePciInfo(device, link): + c_pci = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemotePciInfo_v2") + ret = fn(device, link, byref(c_pci)) + _nvmlCheckReturn(ret) + 
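+    # c_pci is an nvmlPciInfo_t describing the device on the remote end of this NvLink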
return c_pci + +def nvmlDeviceGetNvLinkRemoteDeviceType(handle, link): + c_type = _nvmlNvLinkDeviceType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemoteDeviceType") + ret = fn(handle, link, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + +def nvmlDeviceGetNvLinkState(device, link): + c_isActive = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkState") + ret = fn(device, link, byref(c_isActive)) + _nvmlCheckReturn(ret) + return c_isActive.value + +def nvmlDeviceGetNvLinkVersion(device, link): + c_version = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkVersion") + ret = fn(device, link, byref(c_version)) + _nvmlCheckReturn(ret) + return c_version.value + +def nvmlDeviceModifyDrainState(pciInfo, newState): + fn = _nvmlGetFunctionPointer("nvmlDeviceModifyDrainState") + ret = fn(pointer(pciInfo), newState) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceQueryDrainState(pciInfo): + c_newState = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceQueryDrainState") + ret = fn(pointer(pciInfo), byref(c_newState)) + _nvmlCheckReturn(ret) + return c_newState.value + +def nvmlDeviceRemoveGpu(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceRemoveGpu") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceDiscoverGpus(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceDiscoverGpus") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + +def nvmlDeviceClearFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceClearFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + +def nvmlDeviceGetVirtualizationMode(handle): + c_virtualization_mode = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVirtualizationMode") + ret = fn(handle, byref(c_virtualization_mode)) + _nvmlCheckReturn(ret) + return c_virtualization_mode.value + +def nvmlDeviceSetVirtualizationMode(handle, virtualization_mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVirtualizationMode") + return fn(handle, virtualization_mode) + +def nvmlDeviceGetVgpuHeterogeneousMode(handle): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return c_vgpuHeterogeneousMode.mode + +def nvmlDeviceSetVgpuHeterogeneousMode(handle, heterogeneous_mode): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + c_vgpuHeterogeneousMode.mode = heterogeneous_mode + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def 
nvmlVgpuInstanceGetPlacementId(vgpuInstance): + c_placement = c_nvmlVgpuPlacementId_v1_t(0) + c_placement.version = VgpuPlacementId_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetPlacementId") + ret = fn(vgpuInstance, byref(c_placement)) + _nvmlCheckReturn(ret) + return c_placement.placementId + +def nvmlDeviceGetVgpuTypeSupportedPlacements(handle, vgpuTypeId, mode=0, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + c_vgpu_placements.mode = mode + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeSupportedPlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + +def nvmlDeviceGetVgpuTypeCreatablePlacements(handle, vgpuTypeId, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeCreatablePlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + +def nvmlGetVgpuDriverCapabilities(capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuDriverCapabilities") + ret = fn(_nvmlVgpuDriverCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceGetVgpuCapabilities(handle, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceSetVgpuCapabilities(handle, capability, state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetSupportedVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no supported vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise 
NVMLError(ret) + +def nvmlDeviceGetCreatableVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCreatableVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no supported vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuTypeGetGpuInstanceProfileId(vgpuTypeId): + c_profile_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGpuInstanceProfileId") + ret = fn(vgpuTypeId, byref(c_profile_id)) + _nvmlCheckReturn(ret) + return (c_profile_id.value) + +@convertStrBytes +def nvmlVgpuTypeGetClass(vgpuTypeId): + c_class = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetClass") + ret = fn(vgpuTypeId, c_class, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_class.value + +@convertStrBytes +def nvmlVgpuTypeGetName(vgpuTypeId): + c_name = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetName") + ret = fn(vgpuTypeId, c_name, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_name.value + +def nvmlVgpuTypeGetDeviceID(vgpuTypeId): + c_device_id = c_ulonglong(0) + c_subsystem_id = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetDeviceID") + ret = fn(vgpuTypeId, byref(c_device_id), byref(c_subsystem_id)) + _nvmlCheckReturn(ret) + return (c_device_id.value, c_subsystem_id.value) + +def nvmlVgpuTypeGetFramebufferSize(vgpuTypeId): + c_fb_size = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFramebufferSize") + ret = fn(vgpuTypeId, byref(c_fb_size)) + _nvmlCheckReturn(ret) + return c_fb_size.value + +def nvmlVgpuTypeGetNumDisplayHeads(vgpuTypeId): + c_num_heads = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetNumDisplayHeads") + ret = fn(vgpuTypeId, byref(c_num_heads)) + _nvmlCheckReturn(ret) + return c_num_heads.value + +def nvmlVgpuTypeGetResolution(vgpuTypeId): + c_xdim = c_uint(0) + c_ydim = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetResolution") + ret = fn(vgpuTypeId, 0, byref(c_xdim), byref(c_ydim)) + _nvmlCheckReturn(ret) + return (c_xdim.value, c_ydim.value) + +@convertStrBytes +def nvmlVgpuTypeGetLicense(vgpuTypeId): + c_license = create_string_buffer(NVML_GRID_LICENSE_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetLicense") + ret = fn(vgpuTypeId, c_license, c_buffer_size) + _nvmlCheckReturn(ret) + return c_license.value + +def nvmlVgpuTypeGetFrameRateLimit(vgpuTypeId): + c_frl_config = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFrameRateLimit") + ret = fn(vgpuTypeId, byref(c_frl_config)) + _nvmlCheckReturn(ret) + return c_frl_config.value + +def nvmlVgpuTypeGetGspHeapSize(vgpuTypeId): + c_gsp_heap = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGspHeapSize") + ret = fn(vgpuTypeId, byref(c_gsp_heap)) + _nvmlCheckReturn(ret) + return c_gsp_heap.value + +def nvmlVgpuTypeGetFbReservation(vgpuTypeId): + c_fb_reservation = 
c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFbReservation") + ret = fn(vgpuTypeId, byref(c_fb_reservation)) + _nvmlCheckReturn(ret) + return c_fb_reservation.value + +def nvmlVgpuInstanceGetRuntimeStateSize(vgpuInstance): + c_runtime_state = nvmlVgpuRuntimeState_v1_t() + c_runtime_state.version = VgpuRuntimeState_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetRuntimeStateSize") + ret = fn(vgpuInstance, byref(c_runtime_state)) + _nvmlCheckReturn(ret) + return c_runtime_state + +def nvmlVgpuTypeGetMaxInstances(handle, vgpuTypeId): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + return c_max_instances.value + +def nvmlVgpuTypeGetMaxInstancesPerVm(vgpuTypeId): + c_max_instances_per_vm = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstancesPerVm") + ret = fn(vgpuTypeId, byref(c_max_instances_per_vm)) + _nvmlCheckReturn(ret) + return c_max_instances_per_vm.value + +def nvmlVgpuTypeGetBAR1Info(vgpuTypeId): + c_bar1Info = c_nvmlVgpuTypeBar1Info_v1_t(0) + c_bar1Info.version = VgpuTypeBar1Info_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetBAR1Info") + ret = fn(vgpuTypeId, byref(c_bar1Info)) + _nvmlCheckReturn(ret) + return c_bar1Info + +def nvmlDeviceGetActiveVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetActiveVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_instance_array = _nvmlVgpuInstance_t * c_vgpu_count.value + c_vgpu_instances = vgpu_instance_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_instances) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_instances[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + +@convertStrBytes +def nvmlVgpuInstanceGetVmID(vgpuInstance): + c_vm_id = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + c_vm_id_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmID") + ret = fn(vgpuInstance, byref(c_vm_id), c_buffer_size, byref(c_vm_id_type)) + _nvmlCheckReturn(ret) + return (c_vm_id.value, c_vm_id_type.value) + +@convertStrBytes +def nvmlVgpuInstanceGetUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlVgpuInstanceGetMdevUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMdevUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlVgpuInstanceGetVmDriverVersion(vgpuInstance): + c_driver_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmDriverVersion") + ret = fn(vgpuInstance, byref(c_driver_version), c_buffer_size) + _nvmlCheckReturn(ret) + return c_driver_version.value + +def 
nvmlVgpuInstanceGetLicenseStatus(vgpuInstance): + c_license_status = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseStatus") + ret = fn(vgpuInstance, byref(c_license_status)) + _nvmlCheckReturn(ret) + return c_license_status.value + +def nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseInfo_v2") + c_license_info = c_nvmlVgpuLicenseInfo_t() + ret = fn(vgpuInstance, byref(c_license_info)) + _nvmlCheckReturn(ret) + return c_license_info + +def nvmlVgpuInstanceGetLicenseInfo(vgpuInstance): + return nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance) + +def nvmlVgpuInstanceGetFrameRateLimit(vgpuInstance): + c_frl = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFrameRateLimit") + ret = fn(vgpuInstance, byref(c_frl)) + _nvmlCheckReturn(ret) + return c_frl.value + +def nvmlVgpuInstanceGetEccMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEccMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlVgpuInstanceGetType(vgpuInstance): + c_vgpu_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetType") + ret = fn(vgpuInstance, byref(c_vgpu_type)) + _nvmlCheckReturn(ret) + return c_vgpu_type.value + +def nvmlVgpuInstanceGetEncoderCapacity(vgpuInstance): + c_encoder_capacity = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderCapacity") + ret = fn(vgpuInstance, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + +def nvmlVgpuInstanceSetEncoderCapacity(vgpuInstance, encoder_capacity): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceSetEncoderCapacity") + return fn(vgpuInstance, encoder_capacity) + +def nvmlVgpuInstanceGetFbUsage(vgpuInstance): + c_fb_usage = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFbUsage") + ret = fn(vgpuInstance, byref(c_fb_usage)) + _nvmlCheckReturn(ret) + return c_fb_usage.value + +def nvmlVgpuTypeGetCapabilities(vgpuTypeId, capability): + c_cap_result = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetCapabilities") + ret = fn(vgpuTypeId, _nvmlVgpuCapability_t(capability), byref(c_cap_result)) + _nvmlCheckReturn(ret) + return (c_cap_result.value) + +def nvmlVgpuInstanceGetGpuInstanceId(vgpuInstance): + c_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuInstanceId") + ret = fn(vgpuInstance, byref(c_id)) + _nvmlCheckReturn(ret) + return (c_id.value) + +@convertStrBytes +def nvmlVgpuInstanceGetGpuPciId(vgpuInstance): + c_vgpuPciId = create_string_buffer(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuPciId") + ret = fn(vgpuInstance, c_vgpuPciId, byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE))) + _nvmlCheckReturn(ret) + return c_vgpuPciId.value + +def nvmlDeviceGetVgpuUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_value_type = _nvmlValueType_t() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuUtilization") + ret = fn(handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuInstanceUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), 
c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetVgpuInstancesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuUtilInfo = c_nvmlVgpuInstancesUtilizationInfo_v1_t(0) + c_vgpuUtilInfo.version = VgpuInstancesUtilizationInfo_v1 + c_vgpuUtilInfo.sampleValType = _nvmlValueType_t() + c_vgpuUtilInfo.vgpuInstanceCount = c_uint(0) + c_vgpuUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuInstancesUtilizationInfo") + ret = fn(handle, byref(c_vgpuUtilInfo)) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpuUtilInfo.vgpuInstanceCount * c_nvmlVgpuInstanceUtilizationInfo_v1_t + c_samples = sampleArray() + c_vgpuUtilInfo.vgpuUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpuUtilInfo.vgpuInstanceCount] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetP2PStatus(device1, device2, p2pIndex): + c_p2pstatus = _nvmlGpuP2PStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetP2PStatus") + ret = fn(device1, device2,p2pIndex, byref(c_p2pstatus)) + _nvmlCheckReturn(ret) + return c_p2pstatus.value + +def nvmlDeviceGetGridLicensableFeatures_v4(handle): + c_get_grid_licensable_features = c_nvmlGridLicensableFeatures_v4_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGridLicensableFeatures_v4") + ret = fn(handle, byref(c_get_grid_licensable_features)) + _nvmlCheckReturn(ret) + + return (c_get_grid_licensable_features) + +def nvmlDeviceGetGridLicensableFeatures(handle): + return nvmlDeviceGetGridLicensableFeatures_v4(handle) + +def nvmlDeviceGetGspFirmwareVersion(handle, version=None): + isUserDefined = version is not None + if not isUserDefined: + version = (c_char * NVML_GSP_FIRMWARE_VERSION_BUF_SIZE)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareVersion") + ret = fn(handle, version) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else version.value + +def nvmlDeviceGetGspFirmwareMode(handle, isEnabled=c_uint(), defaultMode=c_uint()): + isReference = type(isEnabled) is not c_uint + isEnabledRef = isEnabled if isReference else byref(isEnabled) + defaultModeRef = defaultMode if isReference else byref(defaultMode) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareMode") + ret = fn(handle, isEnabledRef, defaultModeRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [isEnabled.value, defaultMode.value] + +def nvmlDeviceGetEncoderCapacity(handle, encoderQueryType): + c_encoder_capacity = c_ulonglong(0) + c_encoderQuery_type = _nvmlEncoderQueryType_t(encoderQueryType) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderCapacity") + ret = fn(handle, c_encoderQuery_type, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + +def nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessUtilization") + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuProcessUtilizationSample_t + 
c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetVgpuProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuProcUtilInfo = c_nvmlVgpuProcessesUtilizationInfo_v1_t(0) + c_vgpuProcUtilInfo.version = VgpuProcessesUtilizationInfo_v1 + c_vgpuProcUtilInfo.vgpuProcessCount = c_uint(0) + c_vgpuProcUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessesUtilizationInfo") + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpuProcUtilInfo.vgpuProcessCount * c_nvmlVgpuProcessUtilizationInfo_v1_t + c_samples = sampleArray() + c_vgpuProcUtilInfo.vgpuProcUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpuProcUtilInfo.vgpuProcessCount] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetEncoderStats(handle): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderStats") + ret = fn(handle, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + +def nvmlDeviceGetEncoderSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderSessions") + ret = fn(handle, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetFBCStats(handle): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCStats") + ret = fn(handle, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + +def nvmlDeviceGetFBCSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCSessions") + ret = fn(handle, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetEncoderStats(vgpuInstance): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderStats") + ret = fn(vgpuInstance, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + 
_nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + +def nvmlVgpuInstanceGetEncoderSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetFBCStats(vgpuInstance): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCStats") + ret = fn(vgpuInstance, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + +def nvmlVgpuInstanceGetFBCSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetProcessUtilization(handle, timeStamp): + # first call to get the size + c_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessUtilization") + ret = fn(handle, None, byref(c_count), c_time_stamp) + + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_count.value * c_nvmlProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_samples, byref(c_count), c_time_stamp) + _nvmlCheckReturn(ret) + + return c_samples[0:c_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_processesUtilInfo = c_nvmlProcessesUtilizationInfo_v1_t(0) + c_processesUtilInfo.version = ProcessesUtilizationInfo_v1 + c_processesUtilInfo.processSamplesCount = c_uint(0) + c_processesUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessesUtilizationInfo") + ret = fn(handle, byref(c_processesUtilInfo)) + + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_processesUtilInfo.processSamplesCount * c_nvmlProcessUtilizationInfo_v1_t + c_samples = sampleArray() + c_processesUtilInfo.procUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_processesUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_processesUtilInfo.processSamplesCount] + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetMetadata(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMetadata") + c_vgpuMetadata = c_nvmlVgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the 
c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuMetadata + +def nvmlDeviceGetVgpuMetadata(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuMetadata") + c_vgpuPgpuMetadata = c_nvmlVgpuPgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuPgpuMetadata + +def nvmlGetVgpuCompatibility(vgpuMetadata, pgpuMetadata): + fn = _nvmlGetFunctionPointer("nvmlGetVgpuCompatibility") + c_vgpuPgpuCompatibility = c_nvmlVgpuPgpuCompatibility_t() + ret = fn(byref(vgpuMetadata), byref(pgpuMetadata), byref(c_vgpuPgpuCompatibility)) + _nvmlCheckReturn(ret) + return c_vgpuPgpuCompatibility + +@convertStrBytes +def nvmlDeviceGetPgpuMetadataString(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPgpuMetadataString") + c_pgpuMetadata = create_string_buffer(NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pgpuMetadata.value, c_bufferSize.value) + +def nvmlDeviceGetVgpuSchedulerLog(handle): + c_vgpu_sched_log = c_nvmlVgpuSchedulerLog_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerLog") + ret = fn(handle, byref(c_vgpu_sched_log)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_log + +def nvmlDeviceGetVgpuSchedulerState(handle): + c_vgpu_sched_state = c_nvmlVgpuSchedulerGetState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerState") + ret = fn(handle, byref(c_vgpu_sched_state)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_state + +def nvmlDeviceGetVgpuSchedulerCapabilities(handle): + c_vgpu_sched_caps = c_nvmlVgpuSchedulerCapabilities_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerCapabilities") + ret = fn(handle, byref(c_vgpu_sched_caps)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_caps + +def nvmlDeviceSetVgpuSchedulerState(handle, sched_state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuSchedulerState") + ret = fn(handle, byref(sched_state)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSetVgpuVersion(vgpuVersion): + fn = _nvmlGetFunctionPointer("nvmlSetVgpuVersion") + ret = fn(byref(vgpuVersion)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGetVgpuVersion(supported=None, current=None): + isUserDefined = (supported is not None) or (current is not None) + if not isUserDefined: + supported = c_nvmlVgpuVersion_t() + current = c_nvmlVgpuVersion_t() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuVersion") + ret = fn(byref(supported), byref(current)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else [(supported.minVersion, + supported.maxVersion), + (current.minVersion, + current.maxVersion)] + +def 
nvmlVgpuInstanceGetAccountingMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlVgpuInstanceGetAccountingPids(vgpuInstance): + c_pidCount = c_uint() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingPids") + ret = fn(vgpuInstance, byref(c_pidCount), None) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + sampleArray = c_pidCount.value * c_uint + c_pidArray = sampleArray() + ret = fn(vgpuInstance, byref(c_pidCount), byref(c_pidArray)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pidCount, c_pidArray) + +def nvmlVgpuInstanceGetAccountingStats(vgpuInstance, pid): + c_accountingStats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingStats") + ret = fn(vgpuInstance, pid, byref(c_accountingStats)) + _nvmlCheckReturn(ret) + return c_accountingStats + +def nvmlVgpuInstanceClearAccountingPids(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceClearAccountingPids") + ret = fn(vgpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGetExcludedDeviceCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlGetExcludedDeviceInfoByIndex(index): + c_index = c_uint(index) + info = c_nvmlExcludedDeviceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceInfoByIndex") + ret = fn(c_index, byref(info)) + _nvmlCheckReturn(ret) + return info + +def nvmlDeviceGetHostVgpuMode(handle): + c_host_vgpu_mode = _nvmlHostVgpuMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHostVgpuMode") + ret = fn(handle, byref(c_host_vgpu_mode)) + _nvmlCheckReturn(ret) + return c_host_vgpu_mode.value + +def nvmlDeviceSetMigMode(device, mode): + c_activationStatus = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMigMode") + ret = fn(device, mode, byref(c_activationStatus)) + _nvmlCheckReturn(ret) + return c_activationStatus.value + +def nvmlDeviceGetMigMode(device): + c_currentMode = c_uint() + c_pendingMode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigMode") + ret = fn(device, byref(c_currentMode), byref(c_pendingMode)) + _nvmlCheckReturn(ret) + return [c_currentMode.value, c_pendingMode.value] + +def nvmlDeviceGetGpuInstanceProfileInfo(device, profile, version=2): + if version == 2: + c_info = c_nvmlGpuInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlGpuInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +# Define function alias for the API exposed by NVML +nvmlDeviceGetGpuInstanceProfileInfoV = nvmlDeviceGetGpuInstanceProfileInfo + +def nvmlDeviceGetGpuInstanceRemainingCapacity(device, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceRemainingCapacity") + ret = fn(device, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetGpuInstancePossiblePlacements(device, profileId, placementsRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstancePossiblePlacements_v2") + ret = fn(device, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def 
nvmlDeviceCreateGpuInstance(device, profileId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstance") + ret = fn(device, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlDeviceCreateGpuInstanceWithPlacement(device, profileId, placement): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstanceWithPlacement") + ret = fn(device, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceDestroy(gpuInstance): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceDestroy") + ret = fn(gpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuInstances(device, profileId, gpuInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstances") + ret = fn(device, profileId, gpuInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuInstanceById(device, gpuInstanceId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceById") + ret = fn(device, gpuInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceGetInfo(gpuInstance): + c_info = c_nvmlGpuInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetInfo") + ret = fn(gpuInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlGpuInstanceGetComputeInstanceProfileInfo(device, profile, engProfile, version=2): + if version == 2: + c_info = c_nvmlComputeInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlComputeInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, engProfile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +# Define function alias for the API exposed by NVML +nvmlGpuInstanceGetComputeInstanceProfileInfoV = nvmlGpuInstanceGetComputeInstanceProfileInfo + +def nvmlGpuInstanceGetComputeInstanceRemainingCapacity(gpuInstance, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceRemainingCapacity") + ret = fn(gpuInstance, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlGpuInstanceGetComputeInstancePossiblePlacements(gpuInstance, profileId, placementsRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstancePossiblePlacements") + ret = fn(gpuInstance, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceCreateComputeInstance(gpuInstance, profileId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstance") + ret = fn(gpuInstance, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceCreateComputeInstanceWithPlacement(gpuInstance, profileId, placement): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstanceWithPlacement") + ret = fn(gpuInstance, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlComputeInstanceDestroy(computeInstance): + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceDestroy") + ret = fn(computeInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceGetComputeInstances(gpuInstance, 
profileId, computeInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstances") + ret = fn(gpuInstance, profileId, computeInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceGetComputeInstanceById(gpuInstance, computeInstanceId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceById") + ret = fn(gpuInstance, computeInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlComputeInstanceGetInfo_v2(computeInstance): + c_info = c_nvmlComputeInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceGetInfo_v2") + ret = fn(computeInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlComputeInstanceGetInfo(computeInstance): + return nvmlComputeInstanceGetInfo_v2(computeInstance) + +def nvmlDeviceIsMigDeviceHandle(device): + c_isMigDevice = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceIsMigDeviceHandle") + ret = fn(device, byref(c_isMigDevice)) + _nvmlCheckReturn(ret) + return c_isMigDevice + +def nvmlDeviceGetGpuInstanceId(device): + c_gpuInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceId") + ret = fn(device, byref(c_gpuInstanceId)) + _nvmlCheckReturn(ret) + return c_gpuInstanceId.value + +def nvmlDeviceGetComputeInstanceId(device): + c_computeInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeInstanceId") + ret = fn(device, byref(c_computeInstanceId)) + _nvmlCheckReturn(ret) + return c_computeInstanceId.value + +def nvmlDeviceGetMaxMigDeviceCount(device): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxMigDeviceCount") + ret = fn(device, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetMigDeviceHandleByIndex(device, index): + c_index = c_uint(index) + migDevice = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigDeviceHandleByIndex") + ret = fn(device, c_index, byref(migDevice)) + _nvmlCheckReturn(ret) + return migDevice + +def nvmlDeviceGetDeviceHandleFromMigDeviceHandle(migDevice): + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDeviceHandleFromMigDeviceHandle") + ret = fn(migDevice, byref(device)) + _nvmlCheckReturn(ret) + return device + +def nvmlDeviceGetAttributes_v2(device): + c_attrs = c_nvmlDeviceAttributes() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAttributes_v2") + ret = fn(device, byref(c_attrs)) + _nvmlCheckReturn(ret) + return c_attrs + +def nvmlDeviceGetAttributes(device): + return nvmlDeviceGetAttributes_v2(device) + +def nvmlDeviceGetRemappedRows(device): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRemappedRows") + c_corr = c_uint() + c_unc = c_uint() + c_bpending = c_uint() + c_bfailure = c_uint() + ret = fn(device, byref(c_corr), byref(c_unc), byref(c_bpending), byref(c_bfailure)) + _nvmlCheckReturn(ret) + return (c_corr.value, c_unc.value, c_bpending.value, c_bfailure.value) + +def nvmlDeviceGetRowRemapperHistogram(device): + c_vals = c_nvmlRowRemapperHistogramValues() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRowRemapperHistogram") + ret = fn(device, byref(c_vals)) + _nvmlCheckReturn(ret) + return c_vals + +def nvmlDeviceGetArchitecture(device): + arch = _nvmlDeviceArchitecture_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetArchitecture") + ret = fn(device, byref(arch)) + _nvmlCheckReturn(ret) + return arch.value + +def nvmlDeviceGetBusType(device): + c_busType = _nvmlBusType_t() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetBusType") + ret = fn(device, byref(c_busType)) + _nvmlCheckReturn(ret) + return c_busType.value + +def nvmlDeviceGetIrqNum(device): + c_irqNum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIrqNum") + ret = fn(device, byref(c_irqNum)) + _nvmlCheckReturn(ret) + return c_irqNum.value + +def nvmlDeviceGetNumGpuCores(device): + c_numCores = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumGpuCores") + ret = fn(device, byref(c_numCores)) + _nvmlCheckReturn(ret) + return c_numCores.value + +def nvmlDeviceGetPowerSource(device): + c_powerSource = _nvmlPowerSource_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerSource") + ret = fn(device, byref(c_powerSource)) + _nvmlCheckReturn(ret) + return c_powerSource.value + +def nvmlDeviceGetMemoryBusWidth(device): + c_memBusWidth = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryBusWidth") + ret = fn(device, byref(c_memBusWidth)) + _nvmlCheckReturn(ret) + return c_memBusWidth.value + +def nvmlDeviceGetPcieLinkMaxSpeed(device): + c_speed = _nvmlPcieLinkMaxSpeed_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieLinkMaxSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetAdaptiveClockInfoStatus(device): + c_adaptiveClockInfoStatus = _nvmlAdaptiveClockInfoStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAdaptiveClockInfoStatus") + ret = fn(device, byref(c_adaptiveClockInfoStatus)) + _nvmlCheckReturn(ret) + return c_adaptiveClockInfoStatus.value + +def nvmlDeviceGetPcieSpeed(device): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetDynamicPstatesInfo(device, c_dynamicpstatesinfo=c_nvmlGpuDynamicPstatesInfo_t()): + isReference = type(c_dynamicpstatesinfo) is not c_nvmlGpuDynamicPstatesInfo_t + dynamicpstatesinfoRef = c_dynamicpstatesinfo if isReference else byref(c_dynamicpstatesinfo) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDynamicPstatesInfo"); + ret = fn(device, dynamicpstatesinfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_dynamicpstatesinfo + +def nvmlDeviceSetFanSpeed_v2(handle, index, speed): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanSpeed_v2"); + ret = fn(handle, index, speed) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetThermalSettings(device, sensorindex, c_thermalsettings=c_nvmlGpuThermalSettings_t()): + isReference = type(c_thermalsettings) is not c_nvmlGpuThermalSettings_t + thermalsettingsRef = c_thermalsettings if isReference else byref(c_thermalsettings) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetThermalSettings"); + ret = fn(device, sensorindex, thermalsettingsRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_thermalsettings.sensor[:] + +def nvmlDeviceGetMinMaxClockOfPState(device, clockType, pstate, minClockMHz=c_uint(), maxClockMHz=c_uint()): + isReference = (type(minClockMHz) is not c_uint) or (type(maxClockMHz) is not c_uint) + minClockMHzRef = minClockMHz if isReference else byref(minClockMHz) + maxClockMHzRef = maxClockMHz if isReference else byref(maxClockMHz) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxClockOfPState"); + ret = fn(device, _nvmlClockType_t(clockType), _nvmlClockType_t(pstate), minClockMHzRef, maxClockMHzRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minClockMHz.value, maxClockMHz.value) + +class c_nvmlClockOffset_t(_PrintableStructure): + 
_fields_ = [
+        ('version', c_uint),
+        ('type', _nvmlClockType_t),
+        ('pstate', _nvmlPstates_t),
+        ('clockOffsetMHz', c_int),
+        ('minClockOffsetMHz', c_int),
+        ('maxClockOffsetMHz', c_int),
+    ]
+
+nvmlClockOffset_v1 = 0x1000018
+
+def nvmlDeviceGetClockOffsets(device, info):
+    fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockOffsets")
+    ret = fn(device, info)
+    _nvmlCheckReturn(ret)
+    return NVML_SUCCESS
+
+def nvmlDeviceSetClockOffsets(device, info):
+    fn = _nvmlGetFunctionPointer("nvmlDeviceSetClockOffsets")
+    ret = fn(device, info)
+    _nvmlCheckReturn(ret)
+    return NVML_SUCCESS
+
+def nvmlDeviceGetSupportedPerformanceStates(device):
+    pstates = []
+    c_count = c_uint(NVML_MAX_GPU_PERF_PSTATES)
+    c_size = sizeof(c_uint)*c_count.value
+
+    # NOTE: use 'c_uint' to represent the size of the nvmlPstate_t enumeration.
+    pstates_array = _nvmlPstates_t * c_count.value
+    c_pstates = pstates_array()
+
+    fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedPerformanceStates")
+    ret = fn(device, c_pstates, c_size)
+    _nvmlCheckReturn(ret)
+
+    for value in c_pstates:
+        if value != NVML_PSTATE_UNKNOWN:
+            pstates.append(value)
+
+    return pstates
+
+def nvmlDeviceGetGpcClkVfOffset(device):
+    offset = c_int32()
+    fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkVfOffset")
+    ret = fn(device, byref(offset))
+    _nvmlCheckReturn(ret)
+    return offset.value
+
+def nvmlDeviceSetGpcClkVfOffset(device, offset):
+    c_offset = c_int32(offset)
+    fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpcClkVfOffset")
+    ret = fn(device, c_offset)
+    _nvmlCheckReturn(ret)
+    return NVML_SUCCESS
+
+def nvmlDeviceGetGpcClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()):
+    isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int)
+    minOffsetRef = minOffset if isReference else byref(minOffset)
+    maxOffsetRef = maxOffset if isReference else byref(maxOffset)
+    fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkMinMaxVfOffset")
+    ret = fn(device, minOffsetRef, maxOffsetRef)
+    _nvmlCheckReturn(ret)
+    return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value)
+
+def nvmlDeviceGetMemClkVfOffset(device):
+    offset = c_int32()
+    fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkVfOffset")
+    ret = fn(device, byref(offset))
+    _nvmlCheckReturn(ret)
+    return offset.value
+
+def nvmlDeviceSetMemClkVfOffset(device, offset):
+    c_offset = c_int32(offset)
+    fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemClkVfOffset")
+    ret = fn(device, c_offset)
+    _nvmlCheckReturn(ret)
+    return NVML_SUCCESS
+
+def nvmlDeviceGetMemClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()):
+    isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int)
+    minOffsetRef = minOffset if isReference else byref(minOffset)
+    maxOffsetRef = maxOffset if isReference else byref(maxOffset)
+
+    fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkMinMaxVfOffset")
+    ret = fn(device, minOffsetRef, maxOffsetRef)
+    _nvmlCheckReturn(ret)
+    return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value)
+
+def nvmlSystemSetConfComputeGpusReadyState(state):
+    c_state = c_uint(state)
+    fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeGpusReadyState")
+    ret = fn(c_state)
+    _nvmlCheckReturn(ret)
+    return NVML_SUCCESS
+
+def nvmlSystemGetConfComputeGpusReadyState():
+    c_state = c_uint()
+    fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeGpusReadyState")
+    ret = fn(byref(c_state))
+    _nvmlCheckReturn(ret)
+    return c_state.value
+
+def nvmlSystemGetConfComputeCapabilities():
+    c_ccSysCaps = c_nvmlConfComputeSystemCaps_t()
+    fn =
_nvmlGetFunctionPointer("nvmlSystemGetConfComputeCapabilities") + ret = fn(byref(c_ccSysCaps)) + _nvmlCheckReturn(ret) + return c_ccSysCaps + +def nvmlSystemGetConfComputeState(): + c_state = c_nvmlConfComputeSystemState_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + +def nvmlSystemGetConfComputeSettings(settings): + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeSettings") + return fn(settings) + +def nvmlDeviceSetConfComputeUnprotectedMemSize(device, c_ccMemSize): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetConfComputeUnprotectedMemSize") + ret = fn(device, c_ccMemSize) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetConfComputeMemSizeInfo(device): + c_ccMemSize = c_nvmlConfComputeMemSizeInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeMemSizeInfo") + ret = fn(device, byref(c_ccMemSize)) + _nvmlCheckReturn(ret) + return c_ccMemSize + +def nvmlDeviceGetConfComputeProtectedMemoryUsage(device): + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeProtectedMemoryUsage") + ret = fn(device, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + +def nvmlDeviceGetConfComputeGpuCertificate(device): + c_cert = c_nvmlConfComputeGpuCertificate_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuCertificate") + ret = fn(device, byref(c_cert)) + _nvmlCheckReturn(ret) + return c_cert + +def nvmlDeviceGetConfComputeGpuAttestationReport(device, c_nonce): + c_attestReport = c_nvmlConfComputeGpuAttestationReport_t() + c_nonce_arr = (c_uint8 * len(c_nonce))(*(c_nonce)) + setattr(c_attestReport, 'nonce', c_nonce_arr) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuAttestationReport") + ret = fn(device, byref(c_attestReport)) + _nvmlCheckReturn(ret) + return c_attestReport + +def nvmlSystemSetConfComputeKeyRotationThresholdInfo(max_atk_adv): + c_keyRotationThrInfo = c_nvmlConfComputeSetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeSetKeyRotationThresholdInfo_v1 + c_keyRotationThrInfo.maxAttackerAdvantage = max_atk_adv + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetConfComputeKeyRotationThresholdInfo(): + c_keyRotationThrInfo = c_nvmlConfComputeGetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeGetKeyRotationThresholdInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return c_keyRotationThrInfo + +## GPM ## +######### + +## Enums/defines + +#### GPM Metric Identifiers +NVML_GPM_METRIC_GRAPHICS_UTIL = 1 # Percentage of time any compute/graphics app was active on the GPU. 0.0 - 100.0 +NVML_GPM_METRIC_SM_UTIL = 2 # Percentage of SMs that were busy. 0.0 - 100.0 +NVML_GPM_METRIC_SM_OCCUPANCY = 3 # Percentage of warps that were active vs theoretical maximum. 0.0 - 100.0 +NVML_GPM_METRIC_INTEGER_UTIL = 4 # Percentage of time the GPU's SMs were doing integer operations. 0.0 - 100.0 +NVML_GPM_METRIC_ANY_TENSOR_UTIL = 5 # Percentage of time the GPU's SMs were doing ANY tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_DFMA_TENSOR_UTIL = 6 # Percentage of time the GPU's SMs were doing DFMA tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_HMMA_TENSOR_UTIL = 7 # Percentage of time the GPU's SMs were doing HMMA tensor operations. 
0.0 - 100.0 +NVML_GPM_METRIC_IMMA_TENSOR_UTIL = 9 # Percentage of time the GPU's SMs were doing IMMA tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_DRAM_BW_UTIL = 10 # Percentage of DRAM bw used vs theoretical maximum. 0.0 - 100.0 +NVML_GPM_METRIC_FP64_UTIL = 11 # Percentage of time the GPU's SMs were doing non-tensor FP64 math. 0.0 - 100.0 +NVML_GPM_METRIC_FP32_UTIL = 12 # Percentage of time the GPU's SMs were doing non-tensor FP32 math. 0.0 - 100.0 +NVML_GPM_METRIC_FP16_UTIL = 13 # Percentage of time the GPU's SMs were doing non-tensor FP16 math. 0.0 - 100.0 +NVML_GPM_METRIC_PCIE_TX_PER_SEC = 20 # PCIe traffic from this GPU in MiB/sec +NVML_GPM_METRIC_PCIE_RX_PER_SEC = 21 # PCIe traffic to this GPU in MiB/sec +NVML_GPM_METRIC_NVDEC_0_UTIL = 30 # Percent utilization of NVDEC 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_1_UTIL = 31 # Percent utilization of NVDEC 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_2_UTIL = 32 # Percent utilization of NVDEC 2. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_3_UTIL = 33 # Percent utilization of NVDEC 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_4_UTIL = 34 # Percent utilization of NVDEC 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_5_UTIL = 35 # Percent utilization of NVDEC 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_6_UTIL = 36 # Percent utilization of NVDEC 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_7_UTIL = 37 # Percent utilization of NVDEC 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_0_UTIL = 40 # Percent utilization of NVJPG 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_1_UTIL = 41 # Percent utilization of NVJPG 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_2_UTIL = 42 # Percent utilization of NVJPG 2. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_3_UTIL = 43 # Percent utilization of NVJPG 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_4_UTIL = 44 # Percent utilization of NVJPG 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_5_UTIL = 45 # Percent utilization of NVJPG 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_6_UTIL = 46 # Percent utilization of NVJPG 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_7_UTIL = 47 # Percent utilization of NVJPG 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_0_UTIL = 50 # Percent utilization of NVOFA 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_1_UTIL = 51 # Percent utilization of NVOFA 1. 
0.0 - 100.0 +NVML_GPM_METRIC_NVLINK_TOTAL_RX_PER_SEC = 60 # NvLink read bandwidth for all links in MiB/sec +NVML_GPM_METRIC_NVLINK_TOTAL_TX_PER_SEC = 61 # NvLink write bandwidth for all links in MiB/sec +NVML_GPM_METRIC_NVLINK_L0_RX_PER_SEC = 62 # NvLink read bandwidth for link 0 in MiB/sec +NVML_GPM_METRIC_NVLINK_L0_TX_PER_SEC = 63 # NvLink write bandwidth for link 0 in MiB/sec +NVML_GPM_METRIC_NVLINK_L1_RX_PER_SEC = 64 # NvLink read bandwidth for link 1 in MiB/sec +NVML_GPM_METRIC_NVLINK_L1_TX_PER_SEC = 65 # NvLink write bandwidth for link 1 in MiB/sec +NVML_GPM_METRIC_NVLINK_L2_RX_PER_SEC = 66 # NvLink read bandwidth for link 2 in MiB/sec +NVML_GPM_METRIC_NVLINK_L2_TX_PER_SEC = 67 # NvLink write bandwidth for link 2 in MiB/sec +NVML_GPM_METRIC_NVLINK_L3_RX_PER_SEC = 68 # NvLink read bandwidth for link 3 in MiB/sec +NVML_GPM_METRIC_NVLINK_L3_TX_PER_SEC = 69 # NvLink write bandwidth for link 3 in MiB/sec +NVML_GPM_METRIC_NVLINK_L4_RX_PER_SEC = 70 # NvLink read bandwidth for link 4 in MiB/sec +NVML_GPM_METRIC_NVLINK_L4_TX_PER_SEC = 71 # NvLink write bandwidth for link 4 in MiB/sec +NVML_GPM_METRIC_NVLINK_L5_RX_PER_SEC = 72 # NvLink read bandwidth for link 5 in MiB/sec +NVML_GPM_METRIC_NVLINK_L5_TX_PER_SEC = 73 # NvLink write bandwidth for link 5 in MiB/sec +NVML_GPM_METRIC_NVLINK_L6_RX_PER_SEC = 74 # NvLink read bandwidth for link 6 in MiB/sec +NVML_GPM_METRIC_NVLINK_L6_TX_PER_SEC = 75 # NvLink write bandwidth for link 6 in MiB/sec +NVML_GPM_METRIC_NVLINK_L7_RX_PER_SEC = 76 # NvLink read bandwidth for link 7 in MiB/sec +NVML_GPM_METRIC_NVLINK_L7_TX_PER_SEC = 77 # NvLink write bandwidth for link 7 in MiB/sec +NVML_GPM_METRIC_NVLINK_L8_RX_PER_SEC = 78 # NvLink read bandwidth for link 8 in MiB/sec +NVML_GPM_METRIC_NVLINK_L8_TX_PER_SEC = 79 # NvLink write bandwidth for link 8 in MiB/sec +NVML_GPM_METRIC_NVLINK_L9_RX_PER_SEC = 80 # NvLink read bandwidth for link 9 in MiB/sec +NVML_GPM_METRIC_NVLINK_L9_TX_PER_SEC = 81 # NvLink write bandwidth for link 9 in MiB/sec +NVML_GPM_METRIC_NVLINK_L10_RX_PER_SEC = 82 # NvLink read bandwidth for link 10 in MiB/sec +NVML_GPM_METRIC_NVLINK_L10_TX_PER_SEC = 83 # NvLink write bandwidth for link 10 in MiB/sec +NVML_GPM_METRIC_NVLINK_L11_RX_PER_SEC = 84 # NvLink read bandwidth for link 11 in MiB/sec +NVML_GPM_METRIC_NVLINK_L11_TX_PER_SEC = 85 # NvLink write bandwidth for link 11 in MiB/sec +NVML_GPM_METRIC_NVLINK_L12_RX_PER_SEC = 86 # NvLink read bandwidth for link 12 in MiB/sec +NVML_GPM_METRIC_NVLINK_L12_TX_PER_SEC = 87 # NvLink write bandwidth for link 12 in MiB/sec +NVML_GPM_METRIC_NVLINK_L13_RX_PER_SEC = 88 # NvLink read bandwidth for link 13 in MiB/sec +NVML_GPM_METRIC_NVLINK_L13_TX_PER_SEC = 89 # NvLink write bandwidth for link 13 in MiB/sec +NVML_GPM_METRIC_NVLINK_L14_RX_PER_SEC = 90 # NvLink read bandwidth for link 14 in MiB/sec +NVML_GPM_METRIC_NVLINK_L14_TX_PER_SEC = 91 # NvLink write bandwidth for link 14 in MiB/sec +NVML_GPM_METRIC_NVLINK_L15_RX_PER_SEC = 92 # NvLink read bandwidth for link 15 in MiB/sec +NVML_GPM_METRIC_NVLINK_L15_TX_PER_SEC = 93 # NvLink write bandwidth for link 15 in MiB/sec +NVML_GPM_METRIC_NVLINK_L16_RX_PER_SEC = 94 # NvLink read bandwidth for link 16 in MiB/sec +NVML_GPM_METRIC_NVLINK_L16_TX_PER_SEC = 95 # NvLink write bandwidth for link 16 in MiB/sec +NVML_GPM_METRIC_NVLINK_L17_RX_PER_SEC = 96 # NvLink read bandwidth for link 17 in MiB/sec +NVML_GPM_METRIC_NVLINK_L17_TX_PER_SEC = 97 # NvLink write bandwidth for link 17 in MiB/sec +NVML_GPM_METRIC_MAX = 98 + +## Structs + +class c_nvmlUnitInfo_t(_PrintableStructure): + 
_fields_ = [ + ('name', c_char * 96), + ('id', c_char * 96), + ('serial', c_char * 96), + ('firmwareVersion', c_char * 96), + ] + +class struct_c_nvmlGpmSample_t(Structure): + pass # opaque handle +c_nvmlGpmSample_t = POINTER(struct_c_nvmlGpmSample_t) + +class c_metricInfo_t(Structure): + _fields_ = [ + ("shortName", c_char_p), + ("longName", c_char_p), + ("unit", c_char_p), + ] + +class c_nvmlGpmMetric_t(_PrintableStructure): + _fields_ = [ + ('metricId', c_uint), + ('nvmlReturn', _nvmlReturn_t), + ('value', c_double), + ('metricInfo', c_metricInfo_t) + ] + +class c_nvmlGpmMetricsGet_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('numMetrics', c_uint), + ('sample1', c_nvmlGpmSample_t), + ('sample2', c_nvmlGpmSample_t), + ('metrics', c_nvmlGpmMetric_t * NVML_GPM_METRIC_MAX) + ] + +NVML_GPM_METRICS_GET_VERSION = 1 + +class c_nvmlGpmSupport_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('isSupportedDevice', c_uint), + ] + +NVML_GPM_SUPPORT_VERSION = 1 + +## Functions + +def nvmlGpmMetricsGet(metricsGet): + fn = _nvmlGetFunctionPointer("nvmlGpmMetricsGet") + ret = fn(byref(metricsGet)) + _nvmlCheckReturn(ret) + return metricsGet + +def nvmlGpmSampleFree(gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleFree") + ret = fn(gpmSample) + _nvmlCheckReturn(ret) + return + +def nvmlGpmSampleAlloc(): + gpmSample = c_nvmlGpmSample_t() + fn = _nvmlGetFunctionPointer("nvmlGpmSampleAlloc") + ret = fn(byref(gpmSample)) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmSampleGet(device, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleGet") + ret = fn(device, gpmSample) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmMigSampleGet(device, gpuInstanceId, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmMigSampleGet") + ret = fn(device, gpuInstanceId, gpmSample) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmQueryDeviceSupport(device): + gpmSupport = c_nvmlGpmSupport_t() + gpmSupport.version = NVML_GPM_SUPPORT_VERSION + fn = _nvmlGetFunctionPointer("nvmlGpmQueryDeviceSupport") + ret = fn(device, byref(gpmSupport)) + _nvmlCheckReturn(ret) + return gpmSupport + +def nvmlGpmSetStreamingEnabled(device, state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlGpmSetStreamingEnabled") + ret = fn(device, c_state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpmQueryIfStreamingEnabled(device): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpmQueryIfStreamingEnabled") + ret = fn(device, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +# Low Power Structure and Function + +NVML_NVLINK_POWER_STATE_HIGH_SPEED = 0x0 +NVML_NVLINK_POWER_STATE_LOW = 0x1 + +NVML_NVLINK_LOW_POWER_THRESHOLD_MIN = 0x1 +NVML_NVLINK_LOW_POWER_THRESHOLD_MAX = 0x1FFF +NVML_NVLINK_LOW_POWER_THRESHOLD_RESET = 0xFFFFFFFF +NVML_NVLINK_LOW_POWER_THRESHOLD_DEFAULT = NVML_NVLINK_LOW_POWER_THRESHOLD_RESET + +class c_nvmlNvLinkPowerThres_t(Structure): + _fields_ = [ + ("lowPwrThreshold", c_uint), + ] + +def nvmlDeviceSetNvLinkDeviceLowPowerThreshold(device, l1threshold): + c_info = c_nvmlNvLinkPowerThres_t() + c_info.lowPwrThreshold = l1threshold + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkDeviceLowPowerThreshold") + ret = fn(device, byref(c_info)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +NVML_GPU_FABRIC_UUID_LEN = 16 + +_nvmlGpuFabricState_t = c_uint +NVML_GPU_FABRIC_STATE_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_STATE_NOT_STARTED = 1 +NVML_GPU_FABRIC_STATE_IN_PROGRESS = 2 +NVML_GPU_FABRIC_STATE_COMPLETED = 3 + +class 
c_nvmlGpuFabricInfo_t(_PrintableStructure): + _fields_ = [ + ("clusterUuid", c_char * NVML_DEVICE_UUID_BUFFER_SIZE), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t) + ] + +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_DEGRADED_BW = 0 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_DEGRADED_BW = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_RECOVERY = 2 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_RECOVERY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_UNHEALTHY = 4 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_UNHEALTHY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ACCESS_TIMEOUT_RECOVERY = 6 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ACCESS_TIMEOUT_RECOVERY = 0x11 + +nvmlGpuFabricInfo_v2 = 0x02000024 + +class c_nvmlGpuFabricInfoV_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("clusterUuid", c_char * NVML_GPU_FABRIC_UUID_LEN), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t), + ("healthMask", c_uint32) + ] + + def __init__(self): + super(c_nvmlGpuFabricInfoV_t, self).__init__(version=nvmlGpuFabricInfo_v2) + +def nvmlDeviceGetGpuFabricInfo(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfo"); + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuFabricInfoV(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfoV"); + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +###################### +## Enums/defines +#### NVML GPU NVLINK BW MODE +NVML_GPU_NVLINK_BW_MODE_FULL = 0x0 +NVML_GPU_NVLINK_BW_MODE_OFF = 0x1 +NVML_GPU_NVLINK_BW_MODE_MIN = 0x2 +NVML_GPU_NVLINK_BW_MODE_HALF = 0x3 +NVML_GPU_NVLINK_BW_MODE_3QUARTER = 0x4 +NVML_GPU_NVLINK_BW_MODE_COUNT = 0x5 + +def nvmlSystemSetNvlinkBwMode(mode): + fn = _nvmlGetFunctionPointer("nvmlSystemSetNvlinkBwMode") + ret = fn(mode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetNvlinkBwMode(): + mode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetNvlinkBwMode") + ret = fn(byref(mode)) + _nvmlCheckReturn(ret) + return mode.value + +_nvmlPowerScopeType_t = c_uint +NVML_POWER_SCOPE_GPU = 0 +NVML_POWER_SCOPE_MODULE = 1 +NVML_POWER_SCOPE_MEMORY = 2 + +class c_nvmlPowerValue_v2_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('powerScope', _nvmlPowerScopeType_t), + ('powerValueMw', c_uint), + ] + _fmt_ = {'': "%d B"} + +nvmlPowerValue_v2 = 0x0200000C + +def nvmlDeviceSetPowerManagementLimit_v2(device, powerScope, powerLimit, version=nvmlPowerValue_v2): + c_powerScope = _nvmlPowerScopeType_t(powerScope) + c_powerValue = c_nvmlPowerValue_v2_t() + c_powerValue.version = c_uint(version) + c_powerValue.powerScope = c_powerScope + c_powerValue.powerValueMw = c_uint(powerLimit) + fn = 
_nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit_v2") + ret = fn(device, byref(c_powerValue)) + return NVML_SUCCESS + +class c_nvmlEccSramErrorStatus_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('aggregateUncParity', c_ulonglong), + ('aggregateUncSecDed', c_ulonglong), + ('aggregateCor', c_ulonglong), + ('volatileUncParity', c_ulonglong), + ('volatileUncSecDed', c_ulonglong), + ('volatileCor', c_ulonglong), + ('aggregateUncBucketL2', c_ulonglong), + ('aggregateUncBucketSm', c_ulonglong), + ('aggregateUncBucketPcie', c_ulonglong), + ('aggregateUncBucketMcu', c_ulonglong), + ('aggregateUncBucketOther', c_ulonglong), + ('bThresholdExceeded', c_uint) + ] + + def __init__(self): + super(c_nvmlEccSramErrorStatus_v1_t, self).__init__(version=nvmlEccSramErrorStatus_v1) + +nvmlEccSramErrorStatus_v1 = 0x1000068 +def nvmlDeviceGetSramEccErrorStatus(device, status): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSramEccErrorStatus") + ret = fn(device, status) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +NVML_DEV_CAP_EGM = (1 << 0) +nvmlDeviceCapabilities_v1 = 0x1000008 + +class c_nvmlDeviceCapabilities_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('capMask', c_uint), + ] + + def __init__(self): + super(c_nvmlDeviceCapabilities_v1_t, self).__init__(version=nvmlDeviceCapabilities_v1) + + +def nvmlDeviceGetCapabilities(device, caps): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCapabilities") + return fn(device, caps) + +class c_nvmlPlatformInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('ibGuid', c_char * 16), + ('rackGuid', c_char * 16), + ('chassisPhysicalSlotNumber', c_char), + ('computeSlotIndex', c_char), + ('nodeIndex', c_char), + ('peerType', c_char), + ('moduleId', c_char) + ] + + def __init__(self): + super(c_nvmlPlatformInfo_v1_t, self).__init__(version=nvmlPlatformInfo_v1) + +nvmlPlatformInfo_v1 = 0x100002c +def nvmlDeviceGetPlatformInfo(device, platformInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPlatformInfo") + ret = fn(device, platformInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +class c_nvmlMask255_t(_PrintableStructure): + _fields_ = [ + ('mask', c_uint * 8), + ] + +NVML_WORKLOAD_POWER_MAX_PROFILES = 255 +NVML_POWER_PROFILE_MAX_P = 0 +NVML_POWER_PROFILE_MAX_Q = 1 +NVML_POWER_PROFILE_COMPUTE = 2 +NVML_POWER_PROFILE_MEMORY_BOUND = 3 +NVML_POWER_PROFILE_NETWORK = 4 +NVML_POWER_PROFILE_BALANCED = 5 +NVML_POWER_PROFILE_LLM_INFERENCE = 6 +NVML_POWER_PROFILE_LLM_TRAINING = 7 +NVML_POWER_PROFILE_RBM = 8 +NVML_POWER_PROFILE_DCPCIE = 9 +NVML_POWER_PROFILE_HMMA_SPARSE = 10 +NVML_POWER_PROFILE_HMMA_DENSE = 11 +NVML_POWER_PROFILE_SYNC_BALANCED = 12 +NVML_POWER_PROFILE_HPC = 13 +NVML_POWER_PROFILE_MIG = 14 +NVML_POWER_PROFILE_MAX = 15 + +nvmlWorkloadPowerProfileInfo_v1 = 0x100002c +class c_nvmlWorkloadPowerProfileInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('profileId', c_uint), + ('priority', c_uint), + ('conflictingmask', c_nvmlMask255_t) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileInfo_v1_t, self).__init__(version=nvmlWorkloadPowerProfileInfo_v1) + +nvmlWorkloadPowerProfileProfilesInfo_v1 = 0x1002bf8 +class c_nvmlWorkloadPowerProfileProfilesInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('perfProfilesMask', c_nvmlMask255_t), + ('perfProfile', c_nvmlWorkloadPowerProfileInfo_v1_t * NVML_WORKLOAD_POWER_MAX_PROFILES) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileProfilesInfo_v1_t, 
self).__init__(version=nvmlWorkloadPowerProfileProfilesInfo_v1) + +nvmlWorkloadPowerProfileCurrentProfiles_v1 = 0x1000064 +class c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('perfProfilesMask', c_nvmlMask255_t), + ('requestedProfilesMask', c_nvmlMask255_t), + ('enforcedProfilesMask', c_nvmlMask255_t) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t, self).__init__(version=nvmlWorkloadPowerProfileCurrentProfiles_v1) + +nvmlWorkloadPowerProfileRequestedProfiles_v1 = 0x1000024 +class c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('requestedProfilesMask', c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t, self).__init__(version=nvmlWorkloadPowerProfileRequestedProfiles_v1) + +def nvmlDeviceWorkloadPowerProfileGetProfilesInfo(device, profilesInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetProfilesInfo") + ret = fn(device, profilesInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileGetCurrentProfiles(device, currentProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetCurrentProfiles") + ret = fn(device, currentProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileSetRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileSetRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileClearRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileClearRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetNvlinkSupportedBwModes(device, supportedBwModes): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkSupportedBwModes") + ret = fn(device, supportedBwModes) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetNvlinkBwMode(device, getBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkBwMode") + ret = fn(device, getBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceSetNvlinkBwMode(device, setBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvlinkBwMode") + ret = fn(device, setBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +nvmlDramEncryptionInfo_v1 = 0x01000008 + +class c_nvmlDramEncryptionInfo_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('encryptionState', _nvmlEnableState_t), + ] + + def __init__(self): + super(c_nvmlDramEncryptionInfo_t, self).__init__(version=nvmlDramEncryptionInfo_v1) + +def nvmlDeviceGetDramEncryptionMode(handle): + c_currState = c_nvmlDramEncryptionInfo_t() + c_pendingState = c_nvmlDramEncryptionInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDramEncryptionMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.encryptionState, c_pendingState.encryptionState] + +# added to API +def nvmlDeviceGetCurrentDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[0] + +# added to API +def nvmlDeviceGetPendingDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[1] + +def nvmlDeviceSetDramEncryptionMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDramEncryptionMode") + c_dramEncryptionMode = 
c_nvmlDramEncryptionInfo_t() + c_dramEncryptionMode.encryptionState = mode; + ret = fn(handle, byref(c_dramEncryptionMode)) + _nvmlCheckReturn(ret) + return None + +# Power Smoothing defines +NVML_POWER_SMOOTHING_MAX_NUM_PROFILES = 5 +NVML_POWER_SMOOTHING_ADMIN_OVERRIDE_NOT_SET = 0xFFFFFFFF +NVML_POWER_SMOOTHING_PROFILE_PARAM_PERCENT_TMP_FLOOR = 0 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_UP_RATE = 1 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_RATE = 2 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_HYSTERESIS = 3 + +nvmlPowerSmoothingState_v1=0x1000008 +class c_nvmlPowerSmoothingState_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('state', c_uint), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingState_v1_t, self).__init__(version=nvmlPowerSmoothingState_v1) + +nvmlPowerSmoothingProfile_v1=0x1000018 +class c_nvmlPowerSmoothingProfile_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('profileId', c_uint), + ('paramId', c_uint), + ('value', c_double), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingProfile_v1_t, self).__init__(version=nvmlPowerSmoothingProfile_v1) + +def nvmlDevicePowerSmoothingActivatePresetProfile(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingActivatePresetProfile") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + +def nvmlDevicePowerSmoothingUpdatePresetProfileParam(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingUpdatePresetProfileParam") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + +def nvmlDevicePowerSmoothingSetState(device, state): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingSetState") + ret = fn(device, state) + _nvmlCheckReturn(ret) + diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 85056158bab54..aade28610b313 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -3,14 +3,15 @@ import enum import json import os +import time from pathlib import Path -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Literal, Optional, Type, Union import huggingface_hub from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, try_to_load_from_cache) from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - LocalEntryNotFoundError, + HFValidationError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) from torch import nn @@ -100,15 +101,33 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, # NB: file_exists will only check for the existence of the config file on # hf_hub. This will fail in offline mode. - try: - return file_exists(model, - config_name, - revision=revision, - token=HF_TOKEN) - except huggingface_hub.errors.OfflineModeIsEnabled: - # Don't raise in offline mode, all we know is that we don't have this - # file cached. - return False + + # Call HF to check if the file exists + # 2 retries and exponential backoff + max_retries = 2 + retry_delay = 2 + for attempt in range(max_retries): + try: + return file_exists(model, + config_name, + revision=revision, + token=HF_TOKEN) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Don't raise in offline mode, + # all we know is that we don't have this + # file cached. 
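Editor's note on the retry logic being added to `file_or_path_exists` in this hunk: it is a bounded retry with exponential backoff wrapped around the `file_exists` probe. Below is a minimal standalone sketch of that pattern; the `call_with_backoff` helper, its `probe` callable, and the retry counts are illustrative stand-ins, not vLLM's actual code.

```python
import logging
import time
from typing import Callable, TypeVar

logger = logging.getLogger(__name__)
T = TypeVar("T")


def call_with_backoff(probe: Callable[[], T],
                      max_retries: int = 2,
                      retry_delay: float = 2.0) -> T:
    """Retry `probe` up to `max_retries` times, doubling the delay each time."""
    for attempt in range(max_retries):
        try:
            return probe()
        except Exception as e:
            logger.error("Probe failed: %s, retrying %d of %d", e, attempt + 1,
                         max_retries)
            if attempt == max_retries - 1:
                raise
            time.sleep(retry_delay)
            retry_delay *= 2
    raise RuntimeError("max_retries must be >= 1")


# Usage sketch: retry a flaky remote existence check (probe stubbed here).
# call_with_backoff(lambda: file_exists("org/model", "config.json"))
```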
+ return False + except Exception as e: + logger.error( + "Error checking file existence: %s, retrying %d of %d", e, + attempt + 1, max_retries) + if attempt == max_retries - 1: + logger.error("Error checking file existence: %s", e) + raise + time.sleep(retry_delay) + retry_delay *= 2 + continue + return False def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -193,10 +212,26 @@ def get_config( # raise an offline mode error to indicate to the user that they # don't have files cached and may need to go online. # This is conveniently triggered by calling file_exists(). - file_exists(model, - HF_CONFIG_NAME, - revision=revision, - token=HF_TOKEN) + + # Call HF to check if the file exists + # 2 retries and exponential backoff + max_retries = 2 + retry_delay = 2 + for attempt in range(max_retries): + try: + file_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=HF_TOKEN) + except Exception as e: + logger.error( + "Error checking file existence: %s, retrying %d of %d", + e, attempt + 1, max_retries) + if attempt == max_retries: + logger.error("Error checking file existence: %s", e) + raise e + time.sleep(retry_delay) + retry_delay *= 2 raise ValueError(f"No supported config format found in {model}") @@ -265,49 +300,66 @@ def get_config( return config +def try_get_local_file(model: Union[str, Path], + file_name: str, + revision: Optional[str] = 'main') -> Optional[Path]: + file_path = Path(model) / file_name + if file_path.is_file(): + return file_path + else: + try: + cached_filepath = try_to_load_from_cache(repo_id=model, + filename=file_name, + revision=revision) + if isinstance(cached_filepath, str): + return Path(cached_filepath) + except HFValidationError: + ... + return None + + def get_hf_file_to_dict(file_name: str, model: Union[str, Path], revision: Optional[str] = 'main'): """ - Downloads a file from the Hugging Face Hub and returns + Downloads a file from the Hugging Face Hub and returns its contents as a dictionary. Parameters: - file_name (str): The name of the file to download. - model (str): The name of the model on the Hugging Face Hub. - - revision (str): The specific version of the model. + - revision (str): The specific version of the model. Returns: - - config_dict (dict): A dictionary containing + - config_dict (dict): A dictionary containing the contents of the downloaded file. """ - file_path = Path(model) / file_name - if file_or_path_exists(model=model, - config_name=file_name, - revision=revision): + file_path = try_get_local_file(model=model, + file_name=file_name, + revision=revision) - if not file_path.is_file(): - try: - hf_hub_file = hf_hub_download(model, - file_name, - revision=revision) - except (RepositoryNotFoundError, RevisionNotFoundError, - EntryNotFoundError, LocalEntryNotFoundError) as e: - logger.debug("File or repository not found in hf_hub_download", - e) - return None - except HfHubHTTPError as e: - logger.warning( - "Cannot connect to Hugging Face Hub. Skipping file " - "download for '%s':", - file_name, - exc_info=e) - return None - file_path = Path(hf_hub_file) + if file_path is None and file_or_path_exists( + model=model, config_name=file_name, revision=revision): + try: + hf_hub_file = hf_hub_download(model, file_name, revision=revision) + except (RepositoryNotFoundError, RevisionNotFoundError, + EntryNotFoundError, LocalEntryNotFoundError) as e: + logger.debug("File or repository not found in hf_hub_download", e) + return None + except HfHubHTTPError as e: + logger.warning( + "Cannot connect to Hugging Face Hub. 
Skipping file " + "download for '%s':", + file_name, + exc_info=e) + return None + file_path = Path(hf_hub_file) + if file_path is not None and file_path.is_file(): with open(file_path) as file: return json.load(file) + return None @@ -328,7 +380,12 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): """ modules_file_name = "modules.json" - modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) + + modules_dict = None + if file_or_path_exists(model=model, + config_name=modules_file_name, + revision=revision): + modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) if modules_dict is None: return None @@ -382,17 +439,17 @@ def get_sentence_transformer_tokenizer_config(model: str, revision: Optional[str] = 'main' ): """ - Returns the tokenization configuration dictionary for a + Returns the tokenization configuration dictionary for a given Sentence Transformer BERT model. Parameters: - - model (str): The name of the Sentence Transformer + - model (str): The name of the Sentence Transformer BERT model. - revision (str, optional): The revision of the m odel to use. Defaults to 'main'. Returns: - - dict: A dictionary containing the configuration parameters + - dict: A dictionary containing the configuration parameters for the Sentence Transformer BERT model. """ sentence_transformer_config_files = [ @@ -404,20 +461,33 @@ def get_sentence_transformer_tokenizer_config(model: str, "sentence_xlm-roberta_config.json", "sentence_xlnet_config.json", ] - try: - # If model is on HuggingfaceHub, get the repo files - repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN) - except Exception as e: - logger.debug("Error getting repo files", e) - repo_files = [] - encoder_dict = None - for config_name in sentence_transformer_config_files: - if config_name in repo_files or Path(model).exists(): - encoder_dict = get_hf_file_to_dict(config_name, model, revision) + + for config_file in sentence_transformer_config_files: + if try_get_local_file(model=model, + file_name=config_file, + revision=revision) is not None: + encoder_dict = get_hf_file_to_dict(config_file, model, revision) if encoder_dict: break + if not encoder_dict: + try: + # If model is on HuggingfaceHub, get the repo files + repo_files = list_repo_files(model, + revision=revision, + token=HF_TOKEN) + except Exception as e: + logger.debug("Error getting repo files", e) + repo_files = [] + + for config_name in sentence_transformer_config_files: + if config_name in repo_files: + encoder_dict = get_hf_file_to_dict(config_name, model, + revision) + if encoder_dict: + break + if not encoder_dict: return None @@ -519,7 +589,8 @@ def recurse_elems(elem: Any): for key, value in elem.items(): key = config_mapping.get(key, key) config_dict[key] = recurse_elems(value) - return PretrainedConfig(**config_dict) + + return config_dict else: return elem @@ -531,12 +602,30 @@ def recurse_elems(elem: Any): config_dict["max_position_embeddings"] = config_dict.get( "max_position_embeddings", 128_000) + if config_dict.get("quantization") is not None: + quantization = config_dict.get("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + # This maps to the FP8 static per-tensor quantization scheme + quantization_config = { + "quant_method": "fp8", + "activation_scheme": "static" + } + else: + raise ValueError( + f"Found unknown quantization='{quantization}' in config") + + config_dict["quantization_config"] = quantization_config + + config_type: Literal["text", + "multimodal"] = 
"multimodal" if config_dict.get( + "vision_encoder") is not None else "text" + if config_dict.get("moe") is not None: config_dict["architectures"] = ["MixtralForCausalLM"] else: config_dict["architectures"] = ["MistralForCausalLM"] - if config_dict.get("vision_encoder") is not None: + if config_type == "multimodal": multimodal_config = config_dict.pop("vision_encoder") config_dict = { @@ -548,8 +637,16 @@ def recurse_elems(elem: Any): config_dict.update(kwargs) - config = recurse_elems(config_dict) - return config + config_dict = recurse_elems(config_dict) + + # transform to HF config format + if config_type == "multimodal": + config_dict["text_config"] = PretrainedConfig( + **config_dict["text_config"]) + config_dict["vision_config"] = PretrainedConfig( + **config_dict["vision_config"]) + + return PretrainedConfig(**config_dict) def get_hf_image_processor_config( diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 99715ba6d0b09..6b2765db94e78 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig): The LoRA configuration for finetuning the text model. audio_model_lora_config (`LoraConfigSimplified`, *optional*): The LoRA configuration for finetuning the audio model. + projector_ln_mid (`bool`, *optional*, defaults to `False`): + Whether to apply layer normalization at the middle of the + projector or at the end. Versions v0.4.1 and below + use `False`, but v0.5 and above use `True`. """ model_type = "ultravox" @@ -56,6 +60,7 @@ def __init__( projector_act: str = "swiglu", text_model_lora_config: Optional[Dict[str, Any]] = None, audio_model_lora_config: Optional[Dict[str, Any]] = None, + projector_ln_mid: bool = False, **kwargs, ): self.ignore_index = ignore_index @@ -68,6 +73,7 @@ def __init__( self.stack_factor = stack_factor self.norm_init = norm_init self.projector_act = projector_act + self.projector_ln_mid = projector_ln_mid if text_model_id is not None: # Avoid circular import diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 8160a35ff2228..a1fa27773fe5c 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -74,6 +74,25 @@ def convert_prompt_ids_to_tokens( return new_tokens, prefix_offset, read_offset +def convert_ids_list_to_tokens( + tokenizer: AnyTokenizer, + token_ids: List[int], +) -> List[str]: + """Detokenize the input ids individually. 
+ + Args: + tokenizer: tokenizer used by model under test + token_ids: convert these tokens (Python list form) + + Returns: + Python list of token string representations + + """ + token_str_lst = tokenizer.convert_ids_to_tokens(token_ids) + _replace_none_with_empty(token_str_lst) # type: ignore + return token_str_lst + + # Based on # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 # under Apache 2.0 license diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 1550f978ed201..f08923e7401f3 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: def find_tokenizer_file(files: List[str]): - file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$") + file_pattern = re.compile( + r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") matched_files = [file for file in files if file_pattern.match(file)] if len(matched_files) > 1: @@ -103,6 +104,42 @@ def find_tokenizer_file(files: List[str]): return matched_files[0] +def make_mistral_chat_completion_request( + messages: List["ChatCompletionMessageParam"], + tools: Optional[List[Dict[str, + Any]]] = None) -> "ChatCompletionRequest": + last_message = cast(Dict[str, Any], messages[-1]) + if last_message["role"] == "assistant": + last_message["prefix"] = True + + last_message = cast(Dict[str, Any], messages[-1]) + if last_message["role"] == "assistant": + last_message["prefix"] = True + + # mistral-common requires AssistantMessage content to be string [1]. + # + # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 + for message in messages: + if message.get("role") == "assistant": + content = message.get("content") + if isinstance(content, list): + content = "\n".join(chunk.get("text") for chunk in content) + message["content"] = content + + # The Mistral client, in comparison to the OpenAI client, requires the + # "parameters" dict to be present, even if it's empty. 
+ if tools: + for function in [ + tool["function"] for tool in tools + if tool["type"] == "function" + ]: + function.setdefault("parameters", {}) + + from mistral_common.protocol.instruct.request import ChatCompletionRequest + return ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] + + class MistralTokenizer: def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: @@ -282,17 +319,10 @@ def encode(self, prompt: str) -> List[int]: def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], - tools: Optional[Dict[str, Any]] = None, + tools: Optional[List[Dict[str, Any]]] = None, **kwargs) -> List[int]: - last_message = cast(Dict[str, Any], messages[-1]) - if last_message["role"] == "assistant": - last_message["prefix"] = True - - from mistral_common.protocol.instruct.request import ( - ChatCompletionRequest) - request = ChatCompletionRequest(messages=messages, - tools=tools) # type: ignore[type-var] + request = make_mistral_chat_completion_request(messages, tools) encoded = self.mistral.encode_chat_completion(request) # encode-decode to get clean prompt diff --git a/vllm/utils.py b/vllm/utils.py index 8b92695987573..e168752766661 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2239,34 +2239,13 @@ def import_pynvml(): This causes errors when both of them are installed. Starting from version 12.0, it migrates to a new module named `pynvml_utils` to avoid the conflict. - - TL;DR: if users have pynvml<12.0 installed, it will cause problems. - Otherwise, `import pynvml` will import the correct module. - We take the safest approach here, to manually import the correct - `pynvml.py` module from the `nvidia-ml-py` package. + It is so confusing that many packages in the community use the + unofficial one by mistake, and we have to handle this case. + For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial + one, and it will cause errors, see the issue + https://github.com/vllm-project/vllm/issues/12847 for example. + After all the troubles, we decide to copy the official `pynvml` + module to our codebase, and use it directly. """ - if TYPE_CHECKING: - import pynvml - return pynvml - if "pynvml" in sys.modules: - import pynvml - if pynvml.__file__.endswith("__init__.py"): - # this is pynvml < 12.0 - raise RuntimeError( - "You are using a deprecated `pynvml` package. " - "Please uninstall `pynvml` or upgrade to at least" - " version 12.0. 
See https://pypi.org/project/pynvml " - "for more information.") - return sys.modules["pynvml"] - import importlib.util - import os - import site - for site_dir in site.getsitepackages(): - pynvml_path = os.path.join(site_dir, "pynvml.py") - if os.path.exists(pynvml_path): - spec = importlib.util.spec_from_file_location( - "pynvml", pynvml_path) - pynvml = importlib.util.module_from_spec(spec) - sys.modules["pynvml"] = pynvml - spec.loader.exec_module(pynvml) - return pynvml + import vllm.third_party.pynvml as pynvml + return pynvml diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 204afc9f4025d..5cb1e2fd26a5c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,7 +10,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION +from vllm.attention.backends.utils import get_flash_attn_version from vllm.logger import init_logger from vllm.utils import cdiv from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -132,6 +132,7 @@ def __init__( "encoder/decoder cross-attention " "are not implemented for " "FlashAttentionImpl") + self.vllm_flash_attn_version = get_flash_attn_version() def forward( self, @@ -205,7 +206,7 @@ def forward( window_size=self.sliding_window, block_table=attn_metadata.block_table, softcap=self.logits_soft_cap, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) return output @@ -227,7 +228,7 @@ def forward( logits_soft_cap=self.logits_soft_cap, block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) return output diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index de349ec120999..f75d31f542cf7 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -10,6 +10,7 @@ generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) +from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus logger = init_logger(__name__) @@ -72,11 +73,34 @@ def __init__( self.req_to_blocks: DefaultDict[str, List[KVCacheBlock]] = defaultdict(list) + # Mapping from request ID to kv block hashes. + # This is to avoid recomputing the block hashes for each call of + # `get_computed_blocks` or `allocate_slots`. + self.req_to_block_hashes: DefaultDict[ + str, List[BlockHashType]] = defaultdict(list) + + self.prefix_cache_stats = PrefixCacheStats() + @property def usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). + """ return 1.0 - (self.free_block_queue.num_free_blocks / self.num_gpu_blocks) + def make_prefix_cache_stats(self) -> PrefixCacheStats: + """Get (and reset) the prefix cache stats. + + Returns: + The current prefix caching stats. + """ + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats + def get_computed_blocks( self, request: Request) -> Tuple[List[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. @@ -97,11 +121,11 @@ def get_computed_blocks( computed_blocks = [] # The block hashes for the request may already be computed - # if the request was preempted and resumed. 
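Editor's note on the KV cache manager hunk here: it moves the cached block hashes off the `Request` object and into a `req_to_block_hashes` mapping owned by the manager, so hashes are computed once, reused across scheduling attempts, and dropped only when the request finishes. A minimal standalone sketch of that lifecycle, with `hash_request_tokens` replaced by an illustrative stub:

```python
from collections import defaultdict
from typing import DefaultDict, List


def stub_hash_request_tokens(block_size: int, token_ids: List[int]) -> List[int]:
    # Stand-in for vLLM's hash_request_tokens(); hashes full blocks only.
    return [
        hash(tuple(token_ids[i:i + block_size]))
        for i in range(0, len(token_ids) - block_size + 1, block_size)
    ]


class BlockHashCache:
    """Compute block hashes lazily per request and free them on finish."""

    def __init__(self, block_size: int = 4):
        self.block_size = block_size
        self.req_to_block_hashes: DefaultDict[str, List[int]] = defaultdict(list)

    def get_block_hashes(self, request_id: str, token_ids: List[int]) -> List[int]:
        block_hashes = self.req_to_block_hashes[request_id]
        if not block_hashes:
            block_hashes = stub_hash_request_tokens(self.block_size, token_ids)
            self.req_to_block_hashes[request_id] = block_hashes
        return block_hashes

    def free_block_hashes(self, request_id: str) -> None:
        # Called only when the request finishes, not when it is preempted.
        self.req_to_block_hashes.pop(request_id, None)


cache = BlockHashCache()
hashes = cache.get_block_hashes("req-0", list(range(10)))
assert cache.get_block_hashes("req-0", list(range(10))) is hashes  # reused
cache.free_block_hashes("req-0")
```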
- if not request.kv_block_hashes: - request.set_kv_block_hashes( - hash_request_tokens(self.block_size, request)) - block_hashes = request.kv_block_hashes + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = hash_request_tokens(self.block_size, request) + self.req_to_block_hashes[request.request_id] = block_hashes for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not @@ -112,6 +136,10 @@ def get_computed_blocks( else: break + self.prefix_cache_stats.requests += 1 + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) + # NOTE(woosuk): Since incomplete blocks are not eligible for # sharing, `num_computed_tokens` is always a multiple of # `block_size`. @@ -199,8 +227,6 @@ def allocate_slots( # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. self.max_num_blocks_per_req - len(req_blocks), ) assert num_new_blocks > 0 @@ -276,6 +302,8 @@ def reset_prefix_cache(self) -> bool: for block in self.block_pool: block.reset_hash() + self.prefix_cache_stats.reset = True + logger.info("Successfully reset prefix cache") return True @@ -299,9 +327,7 @@ def get_num_common_prefix_blocks( While all scheduled requests must be in the RUNNING state, the inverse is not necessarily true. There may be RUNNING requests that are not - scheduled in the current step. As of 1/1/2025, the scheduler does not - allow this case, but it is possible in the future, as we allow more - flexible scheduling. + scheduled in the current step. This can result in an edge case where the number of common prefix blocks is 0, even though all scheduled requests share a common prefix. This @@ -437,7 +463,8 @@ def _cache_full_blocks( full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. """ - num_cached_block_hashes = len(request.kv_block_hashes) + block_hashes = self.req_to_block_hashes[request.request_id] + num_cached_block_hashes = len(block_hashes) # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None @@ -470,7 +497,7 @@ def _cache_full_blocks( # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. - block_hash = request.kv_block_hashes[blk_idx] + block_hash = block_hashes[blk_idx] else: # Otherwise compute the block hash and cache it in the request # in case it will be preempted in the future. @@ -492,9 +519,17 @@ def _cache_full_blocks( # Compute the hash of the current block. block_hash = hash_block_tokens(prev_block_hash_value, block_tokens, extra_keys) - request.append_kv_block_hashes(block_hash) + block_hashes.append(block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash self.cached_block_hash_to_block[block_hash][blk.block_id] = blk prev_block_hash_value = block_hash.hash_value + + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. + + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. 
+ """ + self.req_to_block_hashes.pop(request.request_id, None) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6888f1a3e1823..bddb482d29167 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" +from collections import deque from collections.abc import Sequence from dataclasses import dataclass from typing import Any, List, NamedTuple, Optional, Tuple @@ -8,6 +9,7 @@ from vllm.logger import init_logger from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, KVCacheTensor) +from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request logger = init_logger(__name__) @@ -28,6 +30,68 @@ class BlockHashType(NamedTuple): extra_keys: Optional[Any] = None +class PrefixCachingMetrics: + """Metrics for prefix caching with a hit rate of the most recent N requests. + + Args: + interval: The number of the most recent requests to aggregate. + Defaults to 1000. + """ + + def __init__(self, interval: int = 1000): + self.interval = interval + # The current aggregated values. + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + # A deque of (requests, queries, hits) for the most recent requests. + self.query_queue: deque[Tuple[int, int, int]] = deque() + + def observe(self, stats: PrefixCacheStats): + """Observe the prefix caching for a set of requests. + + This function is called with information gathered when new requests + are being scheduled and are looking for computed blocks. + + When there are more than `interval` requests, the oldest set of + requestsare removed from the metrics. + + Args: + stats: The prefix cache stats. + """ + # reset_prefix_cache was invoked before the current update. + # Reset the metrics before aggregating the current stats. + if stats.reset: + self.reset() + + # Update the metrics. + self.query_queue.append((stats.requests, stats.queries, stats.hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits + + # Remove the oldest stats if the number of requests exceeds. 
+ if self.aggregated_requests > self.interval: + old_requests, old_queries, old_hits = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits + + def reset(self): + """Reset the metrics.""" + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + self.query_queue.clear() + + @property + def hit_rate(self) -> float: + """Calculate the hit rate for the past N requests.""" + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total + + @dataclass class KVCacheBlock: """KV-cache block metadata.""" diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 6c44fec6439e7..985fcf01bb216 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,26 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 from collections import deque -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, - Tuple, Union) +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, + SchedulerOutput) from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -if TYPE_CHECKING: - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.base import PlaceholderRange - logger = init_logger(__name__) @@ -437,6 +431,8 @@ def update_from_output( ) -> EngineCoreOutputs: # NOTE(woosuk): This method doesn't consider speculative decoding. sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] outputs: List[EngineCoreOutput] = [] @@ -471,6 +467,13 @@ def update_from_output( self.encoder_cache_manager.free_encoder_input( request, input_id) + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + + stopped = False + new_logprobs = None + new_token_ids = None + if request.num_computed_tokens == request.num_tokens: req_index = model_runner_output.req_id_to_index[req_id] # NOTE(woosuk): Currently, we assume that each request @@ -486,20 +489,30 @@ def update_from_output( if stopped: self._free_request(request) + # Extract sample logprobs if needed. + if request.sampling_params.logprobs is not None: + assert logprobs is not None + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + new_token_ids = request.output_token_ids[-num_new_tokens:] + + # Transmit partial if chunked prefill & prompt logprobs is enabled + if new_token_ids or prompt_logprobs_tensors is not None: # Add EngineCoreOutput for this Request. 
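Editor's note, stepping back to the `PrefixCachingMetrics` class added in `kv_cache_utils.py` above: it keeps a bounded window of (requests, queries, hits) tuples so the reported hit rate reflects only roughly the most recent `interval` requests. A standalone sketch of the same rolling-window idea, independent of the `PrefixCacheStats` type (the class name and numbers below are illustrative):

```python
from collections import deque
from typing import Deque, Tuple


class RollingHitRate:
    """Aggregate query/hit counts over roughly the last `interval` requests."""

    def __init__(self, interval: int = 1000):
        self.interval = interval
        self.requests = 0
        self.queries = 0
        self.hits = 0
        self.window: Deque[Tuple[int, int, int]] = deque()

    def observe(self, requests: int, queries: int, hits: int) -> None:
        self.window.append((requests, queries, hits))
        self.requests += requests
        self.queries += queries
        self.hits += hits
        # Evict the oldest batch(es) once the window exceeds the interval.
        while self.requests > self.interval and self.window:
            old_requests, old_queries, old_hits = self.window.popleft()
            self.requests -= old_requests
            self.queries -= old_queries
            self.hits -= old_hits

    @property
    def hit_rate(self) -> float:
        return 0.0 if self.queries == 0 else self.hits / self.queries


m = RollingHitRate(interval=4)
m.observe(requests=3, queries=30, hits=15)
m.observe(requests=3, queries=30, hits=30)  # the oldest batch is evicted
assert m.hit_rate == 1.0
```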
- output = EngineCoreOutput( - request_id=req_id, - new_token_ids=request.output_token_ids[-num_new_tokens:], - finished=request.is_finished(), - finish_reason=request.get_finished_reason(), - stop_reason=request.stop_reason) - outputs.append(output) - - # Breakout of the loop. - if stopped: - continue + outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids or [], + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason)) + + if not stopped: + new_running.append(request) - new_running.append(request) self.running = new_running return EngineCoreOutputs( outputs=outputs, @@ -560,6 +573,7 @@ def finish_requests( def _free_request(self, request: Request) -> None: assert request.is_finished() self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] @@ -579,81 +593,5 @@ def make_stats(self) -> SchedulerStats: num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), ) - - -@dataclass -class NewRequestData: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List["MultiModalKwargs"] - mm_hashes: List[str] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - block_ids: List[int] - num_computed_tokens: int - lora_request: Optional[LoRARequest] - - @classmethod - def from_request( - cls, - request: Request, - block_ids: List[int], - num_computed_tokens: int, - ) -> "NewRequestData": - return cls( - req_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, - mm_inputs=request.mm_inputs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, - sampling_params=request.sampling_params, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - lora_request=request.lora_request, - ) - - -@dataclass -class CachedRequestData: - - req_id: str - # If resumed_from_preemption is False, new_block_ids will be appended to - # the request's block IDs. If True, new_block_ids will be used as the - # request's block IDs instead of appending to the existing block IDs. 
- resumed_from_preemption: bool - new_block_ids: List[int] - num_computed_tokens: int - - @classmethod - def from_request( - cls, - request: Request, - resumed_from_preemption: bool, - new_block_ids: List[int], - num_computed_tokens: int, - ) -> "CachedRequestData": - return cls( - req_id=request.request_id, - resumed_from_preemption=resumed_from_preemption, - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, - ) - - -@dataclass -class SchedulerOutput: - - scheduled_new_reqs: List[NewRequestData] - scheduled_cached_reqs: List[CachedRequestData] - - num_scheduled_tokens: Dict[str, int] - total_num_scheduled_tokens: int - scheduled_encoder_inputs: Dict[str, List[int]] - num_common_prefix_blocks: int - - finished_req_ids: Set[str] - free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py new file mode 100644 index 0000000000000..990b3dd0ed780 --- /dev/null +++ b/vllm/v1/core/scheduler_output.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +if TYPE_CHECKING: + from vllm.lora.request import LoRARequest + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.base import PlaceholderRange + from vllm.sampling_params import SamplingParams + from vllm.v1.request import Request + + +@dataclass +class NewRequestData: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List["MultiModalKwargs"] + mm_hashes: List[str] + mm_positions: List["PlaceholderRange"] + sampling_params: "SamplingParams" + block_ids: List[int] + num_computed_tokens: int + lora_request: Optional["LoRARequest"] + + @classmethod + def from_request( + cls, + request: "Request", + block_ids: List[int], + num_computed_tokens: int, + ) -> "NewRequestData": + return cls( + req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, + ) + + +@dataclass +class CachedRequestData: + + req_id: str + # If resumed_from_preemption is False, new_block_ids will be appended to + # the request's block IDs. If True, new_block_ids will be used as the + # request's block IDs instead of appending to the existing block IDs. + resumed_from_preemption: bool + new_block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: "Request", + resumed_from_preemption: bool, + new_block_ids: List[int], + num_computed_tokens: int, + ) -> "CachedRequestData": + return cls( + req_id=request.request_id, + resumed_from_preemption=resumed_from_preemption, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class SchedulerOutput: + + # List of the requests that are scheduled for the first time. + # We cache the request's data in each worker process, so that we don't + # need to re-send it every scheduling step. + scheduled_new_reqs: List[NewRequestData] + # List of the requests that have been scheduled before. + # Since the request's data is already cached in the worker processes, + # we only send the diff to minimize the communication cost. 
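Editor's note on the protocol described by the comments around `scheduled_new_reqs`, `scheduled_cached_reqs`, and `resumed_from_preemption`: the full request data is sent to workers once, and subsequent steps send only a diff, with `new_block_ids` either replacing or extending the cached block IDs. A rough standalone sketch of the worker-side bookkeeping this implies (the `WorkerSideCache` class is an illustrative stand-in, not vLLM code):

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class WorkerSideCache:
    """Illustrative worker-side mirror of per-request block IDs."""
    block_ids: Dict[str, List[int]] = field(default_factory=dict)

    def add_new_request(self, req_id: str, block_ids: List[int]) -> None:
        # First scheduling step: the full data is sent and cached.
        self.block_ids[req_id] = list(block_ids)

    def apply_cached_request(self, req_id: str, resumed_from_preemption: bool,
                             new_block_ids: List[int]) -> None:
        # Later steps: only the diff is sent.
        if resumed_from_preemption:
            self.block_ids[req_id] = list(new_block_ids)
        else:
            self.block_ids[req_id].extend(new_block_ids)


cache = WorkerSideCache()
cache.add_new_request("req-0", [0, 1])
cache.apply_cached_request("req-0", resumed_from_preemption=False,
                           new_block_ids=[2])
assert cache.block_ids["req-0"] == [0, 1, 2]
```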
+ scheduled_cached_reqs: List[CachedRequestData] + + # req_id -> num_scheduled_tokens + # Number of tokens scheduled for each request. + num_scheduled_tokens: Dict[str, int] + # Total number of tokens scheduled for all requests. + # Equal to sum(num_scheduled_tokens.values()) + total_num_scheduled_tokens: int + # req_id -> encoder input indices that need processing. + # E.g., if a request has [0, 1], it could mean the vision encoder needs + # to process that the request's 0-th and 1-th images in the current step. + scheduled_encoder_inputs: Dict[str, List[int]] + # Number of common prefix blocks for all requests. + # This can be used for cascade attention. + num_common_prefix_blocks: int + + # Request IDs that are finished in between the previous and the current + # steps. This is used to notify the workers about the finished requests + # so that they can free the cached states for those requests. + finished_req_ids: Set[str] + # List of (req_id, encoder_input_index) tuples. + # Used to free the encoder cache. + free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index d5933cac50c20..30e1185019d9d 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -1,18 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union import msgspec +from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.inputs import PlaceholderRange +from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats - -if TYPE_CHECKING: - from vllm.lora.request import LoRARequest - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.inputs import PlaceholderRange - from vllm.sampling_params import SamplingParams +from vllm.v1.outputs import LogprobsLists, LogprobsTensors # These are possible values of RequestOutput.finish_reason, # so form part of the external API. @@ -38,8 +36,11 @@ def __str__(self): return FINISH_REASON_STRINGS[self.value] -@dataclass -class EngineCoreRequest: +class EngineCoreRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, # but this object is currently not playing well with msgspec @@ -50,13 +51,13 @@ class EngineCoreRequest: # Detokenizer, but set to None when it is added to EngineCoreClient. 
prompt: Optional[str] prompt_token_ids: List[int] - mm_inputs: Optional[List[Optional["MultiModalKwargs"]]] + mm_inputs: Optional[List[Optional[MultiModalKwargs]]] mm_hashes: Optional[List[str]] - mm_placeholders: Optional[List["PlaceholderRange"]] - sampling_params: "SamplingParams" + mm_placeholders: Optional[List[PlaceholderRange]] + sampling_params: SamplingParams eos_token_id: Optional[int] arrival_time: float - lora_request: Optional["LoRARequest"] + lora_request: Optional[LoRARequest] class EngineCoreOutput( @@ -67,10 +68,17 @@ class EngineCoreOutput( request_id: str new_token_ids: List[int] - finished: bool + + new_logprobs: Optional[LogprobsLists] = None + new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None + @property + def finished(self) -> bool: + return self.finish_reason is not None + class EngineCoreOutputs( msgspec.Struct, @@ -86,16 +94,6 @@ class EngineCoreOutputs( scheduler_stats: SchedulerStats -@dataclass -class EngineCoreProfile: - is_start: bool - - -@dataclass -class EngineCoreResetPrefixCache: - pass - - class EngineCoreRequestType(enum.Enum): """ Request types defined as hex byte strings, so it can be sent over sockets @@ -105,7 +103,3 @@ class EngineCoreRequestType(enum.Enum): ABORT = b'\x01' PROFILE = b'\x02' RESET_PREFIX_CACHE = b'\x03' - - -EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, - EngineCoreResetPrefixCache, List[str]] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 29a9ac1868f27..c90667ba0331e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 -import pickle import queue import signal import threading import time from multiprocessing.connection import Connection -from typing import List, Tuple, Type +from typing import Any, List, Tuple, Type import psutil import zmq import zmq.asyncio -from msgspec import msgpack from vllm.config import VllmConfig from vllm.logger import init_logger @@ -20,13 +18,12 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.core.scheduler import Scheduler -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, - EngineCoreRequest, EngineCoreRequestType, - EngineCoreRequestUnion, EngineCoreResetPrefixCache) +from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -162,7 +159,8 @@ def __init__( # and to overlap some serialization/deserialization with the # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
- self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue() + self.input_queue: queue.Queue[Tuple[EngineCoreRequestType, + Any]] = queue.Queue() self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() threading.Thread(target=self.process_input_socket, args=(input_path, ), @@ -224,7 +222,7 @@ def run_busy_loop(self): while True: try: req = self.input_queue.get(timeout=POLLING_TIMEOUT_S) - self._handle_client_request(req) + self._handle_client_request(*req) break except queue.Empty: logger.debug("EngineCore busy loop waiting.") @@ -234,10 +232,10 @@ def run_busy_loop(self): except BaseException: raise - # 2) Handle any new client requests (Abort or Add). + # 2) Handle any new client requests. while not self.input_queue.empty(): req = self.input_queue.get_nowait() - self._handle_client_request(req) + self._handle_client_request(*req) # 3) Step the engine core. outputs = self.step() @@ -245,54 +243,46 @@ def run_busy_loop(self): # 5) Put EngineCoreOutputs into the output queue. self.output_queue.put_nowait(outputs) - def _handle_client_request(self, request: EngineCoreRequestUnion) -> None: - """Handle EngineCoreRequest or EngineCoreABORT from Client.""" + def _handle_client_request(self, request_type: EngineCoreRequestType, + request: Any) -> None: + """Dispatch request from client.""" - if isinstance(request, EngineCoreRequest): + if request_type == EngineCoreRequestType.ADD: self.add_request(request) - elif isinstance(request, EngineCoreProfile): - self.model_executor.profile(request.is_start) - elif isinstance(request, EngineCoreResetPrefixCache): - self.reset_prefix_cache() - else: - # TODO: make an EngineCoreAbort wrapper - assert isinstance(request, list) + elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) + elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE: + self.reset_prefix_cache() + elif request_type == EngineCoreRequestType.PROFILE: + self.model_executor.profile(request) def process_input_socket(self, input_path: str): """Input socket IO thread.""" # Msgpack serialization decoding. - decoder_add_req = PickleEncoder() - decoder_abort_req = PickleEncoder() + add_request_decoder = MsgpackDecoder(EngineCoreRequest) + generic_decoder = MsgpackDecoder() with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: while True: # (RequestType, RequestData) type_frame, data_frame = socket.recv_multipart(copy=False) - request_type = type_frame.buffer - request_data = data_frame.buffer + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. - if request_type == EngineCoreRequestType.ADD.value: - request = decoder_add_req.decode(request_data) - elif request_type == EngineCoreRequestType.ABORT.value: - request = decoder_abort_req.decode(request_data) - elif request_type in ( - EngineCoreRequestType.PROFILE.value, - EngineCoreRequestType.RESET_PREFIX_CACHE.value): - request = pickle.loads(request_data) - else: - raise ValueError(f"Unknown RequestType: {request_type}") + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frame.buffer) # Push to input queue for core busy loop. - self.input_queue.put_nowait(request) + self.input_queue.put_nowait((request_type, request)) def process_output_socket(self, output_path: str): """Output socket IO thread.""" # Msgpack serialization encoding. - encoder = msgpack.Encoder() + encoder = MsgpackEncoder() # Reuse send buffer. 
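Editor's note on the serialization changes in this file: the pickle-based transport is replaced with msgpack via `MsgpackEncoder`/`MsgpackDecoder` (defined in `vllm/v1/serial_utils.py`, not shown in this diff), with a typed decoder for `EngineCoreRequest` and a generic decoder for everything else. A rough standalone illustration of the underlying msgspec pattern, with a made-up `DemoRequest` struct standing in for `EngineCoreRequest`:

```python
from typing import List, Optional

import msgspec


class DemoRequest(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    # Illustrative stand-in for EngineCoreRequest; not the real schema.
    request_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str] = None


encoder = msgspec.msgpack.Encoder()
# A typed decoder validates and reconstructs the struct on the receiving side.
typed_decoder = msgspec.msgpack.Decoder(DemoRequest)
# An untyped decoder returns plain Python objects, which is enough for
# simple payloads such as a list of request ids to abort.
generic_decoder = msgspec.msgpack.Decoder()

payload = encoder.encode(DemoRequest("req-0", [1, 2, 3], prompt="hi"))
assert typed_decoder.decode(payload).request_id == "req-0"

abort_payload = encoder.encode(["req-0", "req-1"])
assert generic_decoder.decode(abort_payload) == ["req-0", "req-1"]
```

Keeping one typed decoder per message schema and a generic fallback mirrors the split between `add_request_decoder` and `generic_decoder` in the input-socket thread above.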
buffer = bytearray() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 247380ef7cfed..2d7d6b42ced52 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -5,9 +5,8 @@ import signal import weakref from abc import ABC, abstractmethod -from typing import List, Optional, Type +from typing import Any, List, Optional, Type -import msgspec import zmq import zmq.asyncio @@ -15,12 +14,11 @@ from vllm.logger import init_logger from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree, make_zmq_socket) -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, - EngineCoreRequest, EngineCoreRequestType, - EngineCoreRequestUnion, EngineCoreResetPrefixCache) +from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.utils import BackgroundProcHandle logger = init_logger(__name__) @@ -162,8 +160,8 @@ def sigusr1_handler(signum, frame): signal.signal(signal.SIGUSR1, sigusr1_handler) # Serialization setup. - self.encoder = PickleEncoder() - self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) + self.encoder = MsgpackEncoder() + self.decoder = MsgpackDecoder(EngineCoreOutputs) # ZMQ setup. self.ctx = ( @@ -221,7 +219,7 @@ def get_output(self) -> EngineCoreOutputs: return self.decoder.decode(frame.buffer) def _send_input(self, request_type: EngineCoreRequestType, - request: EngineCoreRequestUnion) -> None: + request: Any) -> None: # (RequestType, SerializedRequest) msg = (request_type.value, self.encoder.encode(request)) @@ -238,12 +236,10 @@ def abort_requests(self, request_ids: List[str]) -> None: self._send_input(EngineCoreRequestType.ABORT, request_ids) def profile(self, is_start: bool = True) -> None: - self._send_input(EngineCoreRequestType.PROFILE, - EngineCoreProfile(is_start)) + self._send_input(EngineCoreRequestType.PROFILE, is_start) def reset_prefix_cache(self) -> None: - self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, - EngineCoreResetPrefixCache()) + self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None) class AsyncMPClient(MPClient): @@ -278,7 +274,7 @@ async def process_outputs_socket(): return self.decoder.decode(await self.outputs_queue.get()) async def _send_input(self, request_type: EngineCoreRequestType, - request: EngineCoreRequestUnion) -> None: + request: Any) -> None: msg = (request_type.value, self.encoder.encode(request)) await self.input_socket.send_multipart(msg, copy=False) @@ -294,9 +290,7 @@ async def abort_requests_async(self, request_ids: List[str]) -> None: await self._send_input(EngineCoreRequestType.ABORT, request_ids) async def profile_async(self, is_start: bool = True) -> None: - await self._send_input(EngineCoreRequestType.PROFILE, - EngineCoreProfile(is_start)) + await self._send_input(EngineCoreRequestType.PROFILE, is_start) async def reset_prefix_cache_async(self) -> None: - await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, - EngineCoreResetPrefixCache()) + await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 861fcb012c34e..629da06f4925b 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,27 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import 
dataclass -from typing import List, Optional, Union +from typing import List, Optional from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) -@dataclass -class DetokenizerOutput: - output_text: str - token_ids: List[int] - finished: bool - finish_reason: Optional[FinishReason] = None - stop_reason: Union[int, str, None] = None - - @dataclass class IncrementalDetokenizer: @@ -42,7 +32,6 @@ class IncrementalDetokenizer: # Parameters for detokenization skip_special_tokens: bool spaces_between_special_tokens: bool - output_kind: RequestOutputKind # Tokenizer for this request tokenizer: AnyTokenizer @@ -90,25 +79,19 @@ def from_new_request( skip_special_tokens=request.sampling_params.skip_special_tokens, spaces_between_special_tokens=request.sampling_params. spaces_between_special_tokens, - output_kind=request.sampling_params.output_kind, prompt_len=len(request.prompt_token_ids), tokenizer=tokenizer, stop_buffer_length=stop_buffer_length, ) - def update_from_output( - self, - output: EngineCoreOutput, - ) -> Optional[DetokenizerOutput]: + def update(self, new_token_ids: List[int]) -> Optional[str]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. - 2) Update the RequestOutput with the new text. - """ + 2) Evaluate stop criteria. - new_token_ids = output.new_token_ids - finish_reason = output.finish_reason - stop_reason = output.stop_reason + Return matched stop string or None. + """ # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of @@ -131,11 +114,13 @@ def update_from_output( self.tokens.extend(new_tokens) self.prefix_offset = prefix_offset self.read_offset = read_offset - self.output_text += new_decoded_token_text decoded_text += new_decoded_token_text + self.output_text += decoded_text + # 2) Evaluate stop criteria. + stop_string = None if self.stop: stop = StopChecker.check_stop_strings( output_text=self.output_text, @@ -144,28 +129,13 @@ def update_from_output( include_in_output=self.include_stop_str_in_output, ) if stop is not None: - stop_str, truncate_to = stop + stop_string, truncate_to = stop if truncate_to != -1: self.output_text = self.output_text[:truncate_to] - finish_reason = FinishReason.STOP - stop_reason = stop_str - - # TODO: handle stop_token_ids here too? - - # 3) Update the RequestOutput object with the new text. 
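Editor's note on the reworked `update()` in this hunk: instead of packaging a `DetokenizerOutput`, it truncates `output_text` at a matched stop string and returns that string to the caller. A small self-contained sketch of the truncation step, with a simplified stand-in for `StopChecker.check_stop_strings` (the real checker lives in `vllm/engine/output_processor/stop_checker.py` and is not reproduced here):

```python
from typing import List, Optional, Tuple


def check_stop_strings(output_text: str, new_char_count: int, stop: List[str],
                       include_in_output: bool) -> Optional[Tuple[str, int]]:
    # Simplified stand-in: only search text that could overlap the newly
    # appended characters.
    search_start = max(0, len(output_text) - new_char_count - max(map(len, stop)))
    for stop_str in stop:
        idx = output_text.find(stop_str, search_start)
        if idx != -1:
            truncate_to = idx + len(stop_str) if include_in_output else idx
            return stop_str, truncate_to
    return None


output_text = "Hello world<|end|> ignored tail"
result = check_stop_strings(output_text, new_char_count=len(output_text),
                            stop=["<|end|>"], include_in_output=False)
if result is not None:
    stop_string, truncate_to = result
    if truncate_to != -1:
        output_text = output_text[:truncate_to]
    assert output_text == "Hello world" and stop_string == "<|end|>"
```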
- finished = finish_reason is not None - if self.output_kind == RequestOutputKind.FINAL_ONLY \ - and not finished: - return None - - delta = self.output_kind == RequestOutputKind.DELTA - output_text = self._get_next_output_text(finished, delta) - token_ids = new_token_ids if delta else self.output_token_ids - return DetokenizerOutput(output_text, token_ids, finished, - finish_reason, stop_reason) + return stop_string - def _get_next_output_text(self, finished: bool, delta: bool) -> str: + def get_next_output_text(self, finished: bool, delta: bool) -> str: """If delta is True, only new text since the last call to this method is returned""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index e0452bcad7ba7..3ef5a9706063a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -45,6 +45,7 @@ def __init__( multiprocess_mode: bool = False, ) -> None: self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py new file mode 100644 index 0000000000000..4622cafa4a028 --- /dev/null +++ b/vllm/v1/engine/logprobs.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from dataclasses import dataclass +from typing import Dict, List, Optional + +from vllm.logger import init_logger +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.transformers_utils.detokenizer_utils import ( + AnyTokenizer, convert_ids_list_to_tokens) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +logger = init_logger(__name__) + + +@dataclass +class LogprobsProcessor: + + # Tokenizer for this request + tokenizer: AnyTokenizer + + # Logprobs for this request + logprobs: Optional[SampleLogprobs] + prompt_logprobs: Optional[PromptLogprobs] + cumulative_logprob: Optional[float] + num_logprobs: Optional[int] + num_prompt_logprobs: Optional[int] + + @classmethod + def from_new_request( + cls, + tokenizer: AnyTokenizer, + request: EngineCoreRequest, + ) -> "LogprobsProcessor": + num_logprobs = request.sampling_params.logprobs + num_prompt_logprobs = request.sampling_params.prompt_logprobs + return cls( + tokenizer=tokenizer, + cumulative_logprob=(None if num_logprobs is None else 0.), + logprobs=(None if num_logprobs is None else []), + # NOTE: logprob of first prompt token is None. + prompt_logprobs=(None if num_prompt_logprobs is None else [None]), + num_prompt_logprobs=num_prompt_logprobs, + num_logprobs=num_logprobs, + ) + + def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: + """Update with sample logprobs from EngineCore. + + Outer lists are only of len > 1 if EngineCore made + >1 tokens in prior step (e.g. in spec decoding). + + Args: + logprobs_lists: the lists of logprob tokens, logprobs, and ranks. + + """ + + assert self.num_logprobs is not None + assert self.logprobs is not None + assert self.cumulative_logprob is not None + + token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists + + for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, + token_ids_lst): + + # Detokenize (non-incrementally). + decoded_tokens = convert_ids_list_to_tokens( + self.tokenizer, token_ids) + + # Sampler puts the sampled logprob in first. 
+ sampled_token_logprob = logprobs[0] + self.cumulative_logprob += sampled_token_logprob + + # Update with the Logprob dictionary for this pos. + self.logprobs.append( + self._make_logprob_dict( + logprobs, + token_ids, + decoded_tokens, + rank, + self.num_logprobs, + )) + + def _update_prompt_logprobs( + self, + prompt_logprobs_tensors: LogprobsTensors, + ) -> None: + """Update with prompt logprobs from EngineCore. + + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + + """ + + # Prompt logprobs are enabled. + assert self.num_prompt_logprobs is not None + assert self.prompt_logprobs is not None + + token_ids, logprobs, ranks = prompt_logprobs_tensors + + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + decoded_tokens = convert_ids_list_to_tokens( + self.tokenizer, + token_ids.flatten().tolist()) + + # Recover shapes. + num_prompt_tokens, num_logprobs = logprobs.shape + + # Pythonize the torch tensors. + # TODO(rob): experiment with doing this in EngineCore? + prompt_token_ranks = ranks.tolist() + prompt_logprobs = logprobs.tolist() + token_ids = token_ids.tolist() + + # Make Logprob for each position. + for pos in range(num_prompt_tokens): + # Handle flattening. + offset = pos * num_logprobs + offset_end = offset + num_logprobs + decoded_tokens_for_pos = decoded_tokens[offset:offset_end] + + # Update with the Logprob dictionary for this pos. + self.prompt_logprobs.append( + self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + self.num_prompt_logprobs)) + + def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: + """Pop and return all request prompt logprobs + + The logprobs processor aggregates prompt chunk logprobs + over one or more prefill chunks. This method returns + all prompt logprobs at once and then forgets them. + Ensures correct RequestOutputKind.DELTA semantics + wherein all prompt logprobs are returned at once at + the end of prefill. + + Returns: + None if prompt logprobs are disabled for this request. + List of all prompt logprobs, otherwise. + """ + plp = self.prompt_logprobs + if plp: + self.prompt_logprobs = [] + return plp + + @staticmethod + def _make_logprob_dict( + logprobs: List[float], + logprob_token_ids: List[int], + decoded_tokens: List[str], + rank: int, + num_logprobs: int, + ) -> Dict[int, Logprob]: + """Make a Logprob dictionary for a position. + + Args: + logprobs: list of log probabilities + logprob_token_ids: list of top token ids + decoded_tokens: list of decoded top tokens + rank: rank of the sampled token + num_logprobs: number of logprobs requested + by the user (in addition to sampled logprob) + + Returns: + Dict[token id, Logprob] + """ + + # We do not need a special case for the sampled token + # being in the topk, since inserting duplicated data + # into a dictionary twice is the same as doing it once. 
+ topk_ranks = range(1, num_logprobs + 1) + ranks = itertools.chain((rank, ), topk_ranks) + + return { + token_id: Logprob( + logprob=logprob, + rank=rank, + decoded_token=token, + ) + for token_id, logprob, rank, token in zip( + logprob_token_ids, logprobs, ranks, decoded_tokens) + } + + def update_from_output(self, output: EngineCoreOutput) -> None: + if output.new_logprobs is not None: + self._update_sample_logprobs(output.new_logprobs) + if output.new_prompt_logprobs_tensors is not None: + self._update_prompt_logprobs(output.new_prompt_logprobs_tensors) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 9473666914717..5dbf530caa17a 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -5,11 +5,12 @@ from typing import Dict, List, Optional from vllm.outputs import RequestOutput -from vllm.transformers_utils.detokenizer_utils import AnyTokenizer +from vllm.sampling_params import RequestOutputKind +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest -from vllm.v1.engine.detokenizer import (DetokenizerOutput, - IncrementalDetokenizer) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine.detokenizer import IncrementalDetokenizer +from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.metrics.stats import IterationStats, RequestStateStats @@ -26,16 +27,20 @@ class RequestState: def __init__( self, request_id: str, + output_kind: RequestOutputKind, prompt: Optional[str], prompt_token_ids: List[int], + logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], ): self.request_id = request_id + self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.prompt_len = len(prompt_token_ids) + self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.is_prefilling = True self.queue = queue @@ -51,8 +56,13 @@ def from_new_request( ) -> "RequestState": return cls( request_id=request.request_id, + output_kind=request.sampling_params.output_kind, prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, + logprobs_processor=LogprobsProcessor.from_new_request( + tokenizer=tokenizer, + request=request, + ), detokenizer=IncrementalDetokenizer.from_new_request( tokenizer=tokenizer, request=request, @@ -127,13 +137,8 @@ def process_outputs( batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. - If you need to touch every element of the batch, implement a - method called XXXClass.update_from_output() to be called - within the loop below. For examples, see: - * IterationStats.update_from_output() - * Detokenizer.update_from_output() - - TODO(rob): add Protocol makes update_from_output explicit. + If you need to touch every element of the batch, do it from + within the loop below. ********************************************************** """ @@ -154,17 +159,37 @@ def process_outputs( req_state.is_prefilling, req_state.prompt_len, req_state.stats) - req_state.is_prefilling = False - - # 2) Detokenize the token ids into text. - detokenizer_output = req_state.detokenizer.update_from_output( - engine_core_output) - - # 3) Create and handle RequestOutput objects. 
- if detokenizer_output is not None: - request_output = self._make_request_output( - req_state, detokenizer_output) + new_token_ids = engine_core_output.new_token_ids + finish_reason = engine_core_output.finish_reason + + # TODO(andy): prompt logprobs + chunked prefill can + # result in engine core returning an output for a + # partial prefill (in order to send back partial + # prompt logprobs.) This breaks the invariant that + # process_outputs is only operating on engine core + # outputs associated with non-partial completions. + # Currently this is handled by having `is_prefilling` + # check for new decoded tokens, indicating that + # the completion is not partial. + # + # Follow up will aggregate partial prompt logprobs + # in the EngineCore. + req_state.is_prefilling = not new_token_ids + + # 2) Detokenize the token ids into text and check for stop + # strings. + stop_reason = req_state.detokenizer.update(new_token_ids) + if stop_reason: + finish_reason = FinishReason.STOP + + # 3) Compute sample and prompt logprobs for request, + # if required. + req_state.logprobs_processor.update_from_output(engine_core_output) + + # 4) Create and handle RequestOutput objects. + if request_output := self._make_request_output( + req_state, new_token_ids, finish_reason, stop_reason): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put_nowait(request_output) @@ -174,18 +199,16 @@ def process_outputs( # Free completed requests. if request_output.finished: - assert detokenizer_output.finish_reason is not None - self.request_states.pop(req_id) if not engine_core_output.finished: # If req not finished in EngineCore, but Detokenizer # detected stop string, abort needed in EngineCore. reqs_to_abort.append(req_id) - # Track per-request stats + # Track per-request stats. + assert finish_reason is not None iteration_stats.update_from_finished_request( - detokenizer_output.finish_reason, request_output, - req_state.stats) + finish_reason, request_output, req_state.stats) return OutputProcessorOutput( request_outputs=request_outputs, @@ -196,20 +219,47 @@ def process_outputs( @staticmethod def _make_request_output( request_state: RequestState, - detokenizer_output: DetokenizerOutput, - ) -> RequestOutput: + new_token_ids: List[int], + finish_reason: Optional[FinishReason], + stop_reason: Optional[str], + ) -> Optional[RequestOutput]: + + finished = finish_reason is not None + output_kind = request_state.output_kind + # In follow up, we will switch to invariant where EngineCore + # does not stream partial prefills. + if not finished and (request_state.is_prefilling + or output_kind == RequestOutputKind.FINAL_ONLY): + # Only the final output is required in FINAL_ONLY mode. 
+ return None + + detokenizer = request_state.detokenizer + logprobs_processor = request_state.logprobs_processor + + delta = output_kind == RequestOutputKind.DELTA + logprobs = logprobs_processor.logprobs + if delta: + if logprobs: + logprobs = logprobs[-len(new_token_ids):] + # Side effect: logprobs processor forgets prompt logprobs + prompt_logprobs = logprobs_processor.pop_prompt_logprobs() + else: + prompt_logprobs = logprobs_processor.prompt_logprobs + request_output = RequestOutput.new( - request_state.request_id, - request_state.prompt, - request_state.prompt_token_ids, - detokenizer_output.output_text, - detokenizer_output.token_ids, - detokenizer_output.finished, + request_id=request_state.request_id, + prompt=request_state.prompt, + prompt_token_ids=request_state.prompt_token_ids, + text=detokenizer.get_next_output_text(finished, delta), + token_ids=new_token_ids if delta else detokenizer.output_token_ids, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + cumulative_logprob=logprobs_processor.cumulative_logprob, + finished=finished, ) - if detokenizer_output.finished: + if finished: completion_output = request_output.outputs[0] - completion_output.finish_reason = str( - detokenizer_output.finish_reason) - completion_output.stop_reason = detokenizer_output.stop_reason + completion_output.finish_reason = str(finish_reason) + completion_output.stop_reason = stop_reason return request_output diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 366287951ed04..70876b03a8236 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -33,6 +33,7 @@ def __init__( ): self.model_config = model_config + self.cache_config = cache_config self.lora_config = lora_config self.tokenizer = tokenizer @@ -51,6 +52,37 @@ def __init__( self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ cache_config.enable_prefix_caching + def _validate_logprobs( + self, + params: Union[SamplingParams, PoolingParams], + ) -> None: + if not isinstance(params, SamplingParams): + return + + max_logprobs = self.model_config.max_logprobs + # Validate sample logprobs. + if params.logprobs and params.logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {params.logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + # Validate prompt logprobs. + if params.prompt_logprobs and params.prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {params.prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + # TODO(andy): enable this in follow up by recomputing. + if (params.prompt_logprobs is not None + and self.cache_config.enable_prefix_caching): + raise ValueError("Prefix caching with prompt logprobs not yet " + "supported on VLLM V1.") + + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + def process_inputs( self, request_id: str, @@ -64,12 +96,11 @@ def process_inputs( ) -> EngineCoreRequest: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Check max_logprobs # TODO(woosuk): Support encoder-decoder models. 
- if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") + self._validate_logprobs(params) + self._validate_lora(lora_request) + if arrival_time is None: arrival_time = time.time() assert priority == 0, "vLLM V1 does not support priority at the moment." diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index ac10d43eb0d54..093be09ae11bb 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -25,15 +25,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]: parallel_config = vllm_config.parallel_config distributed_executor_backend = ( parallel_config.distributed_executor_backend) - if distributed_executor_backend is None: - # If the user does not specify the distributed executor backend, - # we will choose the backend based on the world size. - if parallel_config.world_size > 1: - distributed_executor_backend = "mp" - else: - distributed_executor_backend = "uni" - - if distributed_executor_backend == "ray": + # distributed_executor_backend must be set in VllmConfig.__post_init__ + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {distributed_executor_backend}.") + executor_class = distributed_executor_backend + elif distributed_executor_backend == "ray": executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index eb1acf584c6b0..3472761dc1808 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,6 +9,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.stats import IterationStats, SchedulerStats @@ -37,6 +38,9 @@ def _reset(self, now): self.num_prompt_tokens: List[int] = [] self.num_generation_tokens: List[int] = [] + # Prefix cache metrics. TODO: Make the interval configurable. + self.prefix_caching_metrics = PrefixCachingMetrics() + def _local_interval_elapsed(self, now: float) -> bool: # Log every _LOCAL_LOGGING_INTERVAL_SEC. elapsed_time = now - self.last_log_time @@ -58,6 +62,8 @@ def log(self, scheduler_stats: SchedulerStats, self._track_iteration_stats(iteration_stats) + self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + now = time.monotonic() if not self._local_interval_elapsed(now): return @@ -72,13 +78,15 @@ def log(self, scheduler_stats: SchedulerStats, logger.info( "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " - "Running: %d reqs, Waiting: %d reqs " - "GPU KV cache usage: %.1f%%.", + "Running: %d reqs, Waiting: %d reqs, " + "GPU KV cache usage: %.1f%%, " + "Prefix cache hit rate: %.1f%%", prompt_throughput, generation_throughput, scheduler_stats.num_running_reqs, scheduler_stats.num_waiting_reqs, scheduler_stats.gpu_cache_usage * 100, + self.prefix_caching_metrics.hit_rate * 100, ) @@ -107,6 +115,18 @@ def __init__(self, model_config: ModelConfig): documentation="GPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames).labels(*labelvalues) + self.counter_gpu_prefix_cache_queries = prometheus_client.Counter( + name="vllm:gpu_prefix_cache_queries", + documentation= + "GPU prefix cache queries, in terms of number of queried blocks.", + labelnames=labelnames).labels(*labelvalues) + + self.counter_gpu_prefix_cache_hits = prometheus_client.Counter( + name="vllm:gpu_prefix_cache_hits", + documentation= + "GPU prefix cache hits, in terms of number of cached blocks.", + labelnames=labelnames).labels(*labelvalues) + self.counter_prompt_tokens = prometheus_client.Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -170,6 +190,11 @@ def log(self, scheduler_stats: SchedulerStats, self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) + self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) self.counter_generation_tokens.inc( iteration_stats.num_generation_tokens) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index e3f1efcc9b1a7..f806b0adf5d5a 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import time -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, List if TYPE_CHECKING: @@ -9,6 +9,20 @@ from vllm.v1.engine import EngineCoreOutput, FinishReason +@dataclass +class PrefixCacheStats: + """Stores prefix cache hit statistics.""" + # Whether reset_prefix_cache was invoked. + reset: bool = False + # The number of requests in this update. + requests: int = 0 + # The number of queries in these requests. Note that "queries" here + # means the number of blocks that were queried from the cache. + queries: int = 0 + # The number of hits in these requests. + hits: int = 0 + + @dataclass class SchedulerStats: """Stats associated with the scheduler.""" @@ -17,7 +31,9 @@ class SchedulerStats: num_waiting_reqs: int = 0 gpu_cache_usage: float = 0.0 - # gpu_prefix_cache_hit_rate: float = 0.0 + + prefix_cache_stats: PrefixCacheStats = field( + default_factory=PrefixCacheStats) @dataclass @@ -60,14 +76,17 @@ def update_from_output(self, output: "EngineCoreOutput", self.num_generation_tokens += num_new_generation_tokens if is_prefilling: - # This relies on the invariant that EngineCore does - # not stream outputs for partially completed prefills - # (scheduler.update_from_output makes EngineCoreOutput - # iff num_computed_tokens == num_tokens). - assert (num_new_generation_tokens > 0) - self.num_prompt_tokens += prompt_len - - self.time_to_first_tokens_iter.append(last_token_latency) + # TODO(andy): we used to assert that num_new_generation_tokens + # > 0 with an invariant that EngineCore does not stream outputs + # for partially completed prefills (scheduler.update_from_output + # makes EngineCoreOutput iff num_computed_tokens == num_tokens). + # When prompt logprobs are enabled, we currently stream out the + # partially completed prompt. + # This will be reverted in a follow up PR and we should re-enable + # this assertion / invariant. 
+ if num_new_generation_tokens > 0: + self.num_prompt_tokens += prompt_len + self.time_to_first_tokens_iter.append(last_token_latency) else: self.time_per_output_tokens_iter.append(last_token_latency) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 6e82bffd7e5c9..27fd2dbda8b28 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,25 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, NamedTuple, Optional import torch -@dataclass -class SamplerOutput: +class LogprobsLists(NamedTuple): + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: List[List[int]] + # [num_reqs, max_num_logprobs + 1] + logprobs: List[List[float]] # [num_reqs] - sampled_token_ids: torch.Tensor + sampled_token_ranks: List[int] + + def slice(self, start: int, end: int): + return LogprobsLists( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.sampled_token_ranks[start:end], + ) + + +class LogprobsTensors(NamedTuple): # [num_reqs, max_num_logprobs + 1] - logprob_token_ids: Optional[torch.Tensor] + logprob_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] - logprobs: Optional[torch.Tensor] + logprobs: torch.Tensor + # [num_reqs] + selected_token_ranks: torch.Tensor - # TODO: Support prompt logprobs. - prompt_logprob_token_ids: Optional[torch.Tensor] - prompt_logprobs: Optional[torch.Tensor] + def tolists(self): + return LogprobsLists( + self.logprob_token_ids.tolist(), + self.logprobs.tolist(), + self.selected_token_ranks.tolist(), + ) + + +@dataclass +class SamplerOutput: + + # [num_reqs] + sampled_token_ids: torch.Tensor + logprobs_tensors: Optional[LogprobsTensors] # ModelRunnerOutput is serialized and sent to the scheduler process. @@ -36,6 +62,12 @@ class ModelRunnerOutput: sampled_token_ids: List[int] # [num_reqs, max_num_logprobs + 1] - logprob_token_ids_cpu: Optional[torch.Tensor] # [num_reqs, max_num_logprobs + 1] - logprobs_cpu: Optional[torch.Tensor] + # [num_reqs] + logprobs: Optional[LogprobsLists] + + # req_id -> (token_ids, logprobs, ranks) + # [prompt_len, num_prompt_logprobs] + # [prompt_len, num_prompt_logprobs] + # [prompt_len] + prompt_logprobs_dict: Dict[str, LogprobsTensors] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 89b39ea615d20..bb4d2c19197bc 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange - from vllm.v1.core.kv_cache_utils import BlockHashType class Request: @@ -63,11 +62,6 @@ def __init__( if self.mm_hashes: assert len(self.mm_inputs) == len(self.mm_hashes) - # Cache the computed kv block hashes of the request to avoid - # recomputing. - self._kv_block_hashes: List[BlockHashType] = [] - self.kv_block_hashes = ConstantList(self._kv_block_hashes) - # Read-only views # Prevent directly appending to the these lists since # they should also be updated simultaneously. 
@@ -124,13 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens - def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: - self._kv_block_hashes = value - self.kv_block_hashes = ConstantList(self._kv_block_hashes) - - def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: - self._kv_block_hashes.append(block_hash) - class RequestStatus(enum.IntEnum): """Status of a request.""" diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 8e54de34548dd..1a2771baba963 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -20,7 +20,8 @@ class SamplingMetadata: generators: Dict[int, torch.Generator] - max_num_logprobs: int + # None means no logprobs, 0 means sampled token logprobs only + max_num_logprobs: Optional[int] no_penalties: bool prompt_token_ids: Optional[torch.Tensor] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 3da7498e0dae5..43fd64aaaa828 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 """A layer that samples the next tokens from the model's outputs.""" -from typing import Tuple import torch import torch.nn as nn -from vllm.v1.outputs import SamplerOutput +from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.penalties import (apply_all_penalties, apply_min_token_penalties) @@ -25,20 +24,16 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - needs_logprobs = sampling_metadata.max_num_logprobs > 0 - if needs_logprobs: - # NOTE(woosuk): Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. - # This is different from the V0 sampler, which uses the logits that - # is used for sampling (after penalties and temperature scaling). - # NOTE: We compute logprobs first because the below ops may - # modify the logits tensor in-place (and we don't want to clone - # the logits tensor for memory efficiency). - topk_logprobs, topk_indices = self.get_topk_logprobs( - logits, sampling_metadata) - else: - topk_logprobs = None - topk_indices = None + + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + # TODO(rob): provide option for logprobs post sampling. + # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None: + raw_logprobs = self.compute_logprobs(logits) # Use float32 for the logits. logits = logits.to(torch.float32) @@ -48,15 +43,19 @@ def forward( logits = self.apply_temperature(logits, sampling_metadata.temperature) # Sample the next token. sampled = self.sample(logits, sampling_metadata) + + # Gather the logprobs of the topk and sampled token (if requested). + # Get logprobs and rank tensors (if requested) + logprobs_tensors = None if num_logprobs is None else \ + self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) + # These are GPU tensors. 
sampler_output = SamplerOutput( sampled_token_ids=sampled, - logprob_token_ids=topk_indices, - logprobs=topk_logprobs, - prompt_logprob_token_ids=None, - prompt_logprobs=None, + logprobs_tensors=logprobs_tensors, ) return sampler_output @@ -103,19 +102,52 @@ def sample( ) return sampled - def get_topk_logprobs( + def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return logits.log_softmax(dim=-1, dtype=torch.float32) + + def gather_logprobs( self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Tuple[torch.Tensor, torch.Tensor]: - logprobs = logits.log_softmax(dim=-1, dtype=torch.float32) - # FIXME: Mask the sampled token_id, get topk logprobs, - # and concatenate the topk with the sampled token_id. - topk_logprobs, topk_indices = torch.topk( - logprobs, sampling_metadata.max_num_logprobs, dim=-1) + logprobs: torch.Tensor, + num_logprobs: int, + token_ids: torch.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + + Args: + logits: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + # Find the topK values. + topk_logprobs, topk_indices = torch.topk(logprobs, + num_logprobs, + dim=-1) + + # Get with the logprob of the prompt or sampled token. + token_ids = token_ids.unsqueeze(-1) + token_logprobs = logprobs.gather(-1, token_ids) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + # Concatenate together with the topk. + indices = torch.cat((token_ids, topk_indices), dim=1) + logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1) + # Use int32 to reduce the tensor size. - topk_indices = topk_indices.to(torch.int32) - return topk_logprobs, topk_indices + indices = indices.to(torch.int32) + + return LogprobsTensors(indices, logprobs, token_ranks) def apply_penalties( self, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 1791dfa2b6325..3f000abcde0d1 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,12 +1,53 @@ # SPDX-License-Identifier: Apache-2.0 import pickle +from typing import Any, Optional +import torch +from msgspec import msgpack -class PickleEncoder: +CUSTOM_TYPE_TENSOR = 1 +CUSTOM_TYPE_PICKLE = 2 - def encode(self, obj): - return pickle.dumps(obj) - def decode(self, data): +class MsgpackEncoder: + """Encoder with custom torch tensor serialization.""" + + def __init__(self): + self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook) + + def encode(self, obj: Any) -> bytes: + return self.encoder.encode(obj) + + def encode_into(self, obj: Any, buf: bytearray) -> None: + self.encoder.encode_into(obj, buf) + + +class MsgpackDecoder: + """Decoder with custom torch tensor serialization.""" + + def __init__(self, t: Optional[Any] = None): + args = () if t is None else (t, ) + self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook) + + def decode(self, obj: Any): + return self.decoder.decode(obj) + + +def custom_enc_hook(obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + # NOTE(rob): it is fastest to use numpy + pickle + # when serializing torch tensors. 
+ # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 + return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) + + return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) + + +def custom_ext_hook(code: int, data: memoryview) -> Any: + if code == CUSTOM_TYPE_TENSOR: + return torch.from_numpy(pickle.loads(data)) + if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) + + raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a31e888656166..d5b8fd2184156 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -176,7 +176,9 @@ def __init__( self.generators: Dict[int, torch.Generator] = {} self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() + # NOTE(rob): num_prompt_logprobs only includes reqs + # that are currently in the prefill phase. + self.num_prompt_logprobs: Dict[str, int] = {} def add_request( self, @@ -238,11 +240,10 @@ def add_request( if request.generator is not None: self.generators[req_index] = request.generator - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs # Add request lora ID if request.lora_request: @@ -272,7 +273,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) + self.num_prompt_logprobs.pop(req_id, None) # LoRA lora_id = self.request_lora_mapping[req_index] @@ -297,7 +298,7 @@ def clear(self) -> None: self.repetition_penalties_reqs.clear() self.generators.clear() self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() + self.num_prompt_logprobs.clear() self.request_lora_mapping.fill(0) self.lora_id_to_lora_request.clear() self.lora_id_to_request_ids.clear() @@ -489,13 +490,9 @@ def no_penalties(self) -> bool: and len(self.repetition_penalties_reqs) == 0) @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 + def max_num_logprobs(self) -> Optional[int]: + return max(self.num_logprobs.values()) if self.num_logprobs else None @property def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 + return not self.num_prompt_logprobs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bfc9d1ca83f45..9b1eab613bf7b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -29,14 +29,14 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import 
LoRAModelRunnerMixin if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.scheduler_output import SchedulerOutput logger = init_logger(__name__) @@ -92,6 +92,7 @@ def __init__( # Multi-modal data support self.input_registry = INPUT_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope # NOTE: Initialized input mapper is only used for processing dummy # multimodal data into multimodal kwargs for GPU memory profiling. @@ -147,7 +148,7 @@ def __init__( device=self.device) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.model_config.uses_mrope: + if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy # position on purpose to make it non-contiguous so that it can work # with torch compile. @@ -284,7 +285,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.model_config.uses_mrope: + if self.uses_mrope: image_grid_thw = [] video_grid_thw = [] second_per_grid_ts = [] @@ -411,7 +412,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.model_config.uses_mrope: + if self.uses_mrope: self._calc_mrope_positions(scheduler_output) # Get token indices. @@ -458,7 +459,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - if self.model_config.uses_mrope: + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], @@ -476,67 +477,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device, non_blocking=True).long() # Prepare for cascade attention if needed. - common_prefix_len = (scheduler_output.num_common_prefix_blocks * - self.block_size) - if common_prefix_len == 0: - # Common case. - use_cascade = False - else: - # NOTE(woosuk): Cascade attention uses two attention kernels: one - # for the common prefix and the other for the rest. For the first - # kernel, we concatenate all the query tokens (possibly from - # different requests) and treat them as if they are from the same - # request. Then, we use bi-directional attention to process the - # common prefix in the KV cache. Importantly, this means that the - # first kernel does not do any masking. - - # Consider the following example: - # Request 1's input query: [D, E, X] - # Request 1's kv cache: [A, B, C, D, E, X] - # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) - # Request 2's input query: [E, Y] - # Request 2's kv cache: [A, B, C, D, E, Y] - # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) - - # If we use [A, B, C, D, E] as the common prefix, then the - # first kernel will compute the bi-directional attention between - # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. - # However, this is wrong because D in Request 1 should not attend to - # E in the common prefix (i.e., we need masking). - # To avoid this, [A, B, C, D] should be the common prefix. - # That is, the common prefix should be capped by the minimum - # num_computed_tokens among the requests, and plus one to include - # the first token of the query. 
- - # In practice, we use [A, B, C] as the common prefix, instead of - # [A, B, C, D] (i.e., the common prefix is capped by the minimum - # num_computed_tokens, without plus one). - # This is because of an implementation detail: We want to always - # use two kernels for cascade attention. Let's imagine: - # Request 3's input query: [D] - # Request 3's kv cache: [A, B, C, D] - # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) - # If we use [A, B, C, D] as the common prefix for Request 1-3, - # then Request 3 will be processed only by the first kernel, - # and the second kernel will get an empty input. While this is not - # a fundamental problem, our current implementation does not support - # this case. - common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) - # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) - use_cascade = FlashAttentionBackend.use_cascade_attention( - common_prefix_len=common_prefix_len, - query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - use_alibi=False, # FIXME - use_sliding_window=self.sliding_window is not None, - num_sms=self.num_sms, - ) - + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + ) + use_cascade = common_prefix_len > 0 if use_cascade: # TODO: Optimize. cu_prefix_query_lens = torch.tensor( @@ -581,6 +526,90 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices + def _compute_cascade_attn_prefix_len( + self, + num_scheduled_tokens: np.ndarray, + num_common_prefix_blocks: int, + ) -> int: + """Compute the length of the common prefix for cascade attention. + + NOTE(woosuk): The common prefix length returned by this function + represents the length used specifically for cascade attention, not the + actual number of tokens shared between requests. When cascade attention + is disabled (use_cascade=False), this function returns 0 even if + requests share common tokens. Additionally, the common prefix length is + truncated to a multiple of the block size and may be further truncated + due to implementation details explained below. + + Args: + num_scheduled_tokens: Number of tokens scheduled per request. + num_common_prefix_blocks: Number of shared KV cache blocks. + + Returns: + int: Length of common prefix in tokens. + """ + common_prefix_len = num_common_prefix_blocks * self.block_size + if common_prefix_len == 0: + # Common case. + return 0 + + # NOTE(woosuk): Cascade attention uses two attention kernels: one + # for the common prefix and the other for the rest. For the first + # kernel, we concatenate all the query tokens (possibly from + # different requests) and treat them as if they are from the same + # request. Then, we use bi-directional attention to process the + # common prefix in the KV cache. Importantly, this means that the + # first kernel does not do any masking. 
+ + # Consider the following example: + # Request 1's input query: [D, E, X] + # Request 1's kv cache: [A, B, C, D, E, X] + # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # Request 2's input query: [E, Y] + # Request 2's kv cache: [A, B, C, D, E, Y] + # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # If we use [A, B, C, D, E] as the common prefix, then the + # first kernel will compute the bi-directional attention between + # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. + # However, this is wrong because D in Request 1 should not attend to + # E in the common prefix (i.e., we need masking). + # To avoid this, [A, B, C, D] should be the common prefix. + # That is, the common prefix should be capped by the minimum + # num_computed_tokens among the requests, and plus one to include + # the first token of the query. + + # In practice, we use [A, B, C] as the common prefix, instead of + # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # num_computed_tokens, without plus one). + # This is because of an implementation detail: We want to always + # use two kernels for cascade attention. Let's imagine: + # Request 3's input query: [D] + # Request 3's kv cache: [A, B, C, D] + # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) + # If we use [A, B, C, D] as the common prefix for Request 1-3, + # then Request 3 will be processed only by the first kernel, + # and the second kernel will get an empty input. While this is not + # a fundamental problem, our current implementation does not support + # this case. + num_reqs = len(num_scheduled_tokens) + common_prefix_len = min( + common_prefix_len, + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # common_prefix_len should be a multiple of the block size. + common_prefix_len = (common_prefix_len // self.block_size * + self.block_size) + use_cascade = FlashAttentionBackend.use_cascade_attention( + common_prefix_len=common_prefix_len, + query_lens=num_scheduled_tokens, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + use_alibi=False, # FIXME + use_sliding_window=self.sliding_window is not None, + num_sms=self.num_sms, + ) + return common_prefix_len if use_cascade else 0 + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 num_reqs = self.input_batch.num_reqs @@ -789,13 +818,14 @@ def execute_model( # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): - positions = self.mrope_positions[:, :num_input_tokens] \ - if self.model_config.uses_mrope \ - else self.positions[:num_input_tokens] hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -804,8 +834,8 @@ def execute_model( inputs_embeds=inputs_embeds, ) hidden_states = hidden_states[:num_scheduled_tokens] - hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(hidden_states, None) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) # Sample the next token and get logprobs if needed. sampling_metadata = self._prepare_sampling(batch_changed) @@ -818,7 +848,8 @@ def execute_model( # the requests one by one. Optimize. 
num_reqs = self.input_batch.num_reqs request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + for i, req_id in enumerate( # type: ignore[assignment] + self.input_batch.req_ids[:num_reqs]): assert req_id is not None req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + @@ -847,27 +878,28 @@ def execute_model( # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. sampled_token_ids = sampler_output.sampled_token_ids.tolist() + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states, + scheduler_output, + ) + # Update with the actual token ids for i, req_state, seq_len in request_seq_lens: token_id = sampled_token_ids[i] self.input_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids[-1] = token_id - if sampler_output.logprob_token_ids is None: - logprob_token_ids = None - else: - logprob_token_ids = sampler_output.logprob_token_ids.cpu() - if sampler_output.logprobs is None: - logprobs = None - else: - logprobs = sampler_output.logprobs.cpu() - model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=sampled_token_ids, - logprob_token_ids_cpu=logprob_token_ids, - logprobs_cpu=logprobs, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, ) return model_runner_output @@ -886,6 +918,76 @@ def load_model(self) -> None: logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) + def _get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + scheduler_output: "SchedulerOutput", + ) -> Dict[str, LogprobsTensors]: + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + prompt_logprobs_dict: Dict[str, LogprobsTensors] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.requests[req_id] + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True) + + # Determine number of logits to retrieve. + start_tok = request.num_computed_tokens + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens < num_remaining_tokens: + # This is a chunk, more tokens remain. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc_np[req_idx].item() + prompt_hidden_states = hidden_states[offset:offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states, None) + + # Get the "target" tokens for each index. 
For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.model.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.model.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids) + + # Transfer GPU->CPU async. + prompt_logprobs_dict[req_id] = LogprobsTensors( + token_ids.to("cpu", non_blocking=True), + logprobs.to("cpu", non_blocking=True), + ranks.to("cpu", non_blocking=True), + ) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + torch.cuda.synchronize() + + return prompt_logprobs_dict + @torch.inference_mode() def _dummy_run( self, @@ -901,10 +1003,11 @@ def _dummy_run( else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] with set_forward_context(None, self.vllm_config): - positions = self.mrope_positions[:, :num_tokens] \ - if self.model_config.uses_mrope \ - else self.positions[:num_tokens] hidden_states = model( input_ids=input_ids, positions=positions, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0adb69073397c..ad53f90b86652 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -18,7 +18,6 @@ from vllm.model_executor import set_random_seed from vllm.platforms import current_platform from vllm.utils import GiB_bytes -from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -26,7 +25,7 @@ logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.scheduler_output import SchedulerOutput class Worker: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index b846d4387ba58..774049a5281ee 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -639,12 +639,25 @@ def load_model(self) -> None: "Bias support in LoRA is not enabled in HPU yet." assert not self.lora_config.fully_sharded_loras, \ "Fully sharded LoRAs is not enabled in HPU yet." + # It's necessary to distinguish between the + # max_position_embeddings of VLMs and LLMs. 
+ if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = ( + self.model.config.max_position_embeddings) + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, - self.vocab_size, self.lora_config, self.device, + self.vocab_size, + self.lora_config, + self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) self.model = self.lora_manager.create_lora_manager(self.model) if self.model_config.quantization == 'inc': diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 12baecde6e42c..c7814f17375b2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -98,7 +98,6 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None scheduler_outputs: Optional[SchedulerOutputs] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
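
The hunks above in `vllm/v1/sample/sampler.py` and `vllm/v1/engine/logprobs.py` center on one operation: for each position, gather the logprob and vocab rank of the chosen (sampled or prompt) token together with the top-k logprobs. Below is a minimal, self-contained sketch of that gathering step, assuming only that `torch` is installed; the standalone `gather_logprobs` helper and the toy inputs are illustrative and are not part of the vLLM API beyond what the diff itself shows.

```python
import torch

def gather_logprobs(logprobs: torch.Tensor, num_logprobs: int,
                    token_ids: torch.Tensor):
    """Return (token_ids, logprobs, ranks): chosen token first, then top-k."""
    # Top-k logprobs and their token ids at each position.
    topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
    # Logprob of the chosen (sampled or prompt) token at each position.
    chosen = token_ids.unsqueeze(-1)                   # [num_tokens, 1]
    chosen_logprobs = logprobs.gather(-1, chosen)      # [num_tokens, 1]
    # Rank of the chosen token; 1 means it was the most likely token.
    ranks = (logprobs >= chosen_logprobs).sum(dim=-1)  # [num_tokens]
    # Chosen token goes first, followed by the top-k entries.
    out_ids = torch.cat((chosen, topk_indices), dim=1)
    out_logprobs = torch.cat((chosen_logprobs, topk_logprobs), dim=1)
    return out_ids, out_logprobs, ranks

# Tiny worked example: 2 positions over a 4-token vocab, top-2 logprobs.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1],
                       [0.1, 3.0, 0.2, 0.3]])
sampled = torch.tensor([1, 1])  # token 1 chosen at both positions
ids, lps, ranks = gather_logprobs(logits.log_softmax(dim=-1), 2, sampled)
# ranks == tensor([2, 1]): token 1 ranks 2nd at position 0 and 1st at position 1.
```

The `vllm/v1/serial_utils.py` hunk swaps the pickle-based transport for msgspec's msgpack with extension hooks, so `torch.Tensor` payloads travel as numpy-pickled `Ext` values while everything else falls back to pickle. A hedged round-trip sketch of that scheme, assuming `msgspec` and `torch` are available (the hook constants mirror the diff, but the example message is made up):

```python
import pickle

import torch
from msgspec import msgpack

CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2

def enc_hook(obj):
    # Tensors go through numpy + pickle; any other type msgpack cannot
    # encode natively falls back to plain pickle.
    if isinstance(obj, torch.Tensor):
        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))

def ext_hook(code: int, data: memoryview):
    if code == CUSTOM_TYPE_TENSOR:
        return torch.from_numpy(pickle.loads(data))
    return pickle.loads(data)

encoder = msgpack.Encoder(enc_hook=enc_hook)
decoder = msgpack.Decoder(ext_hook=ext_hook)

msg = {"request_id": "req-0", "logprobs": torch.tensor([-0.1, -2.3])}
restored = decoder.decode(encoder.encode(msg))
assert torch.equal(restored["logprobs"], msg["logprobs"])
```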