diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 36a7ee37fd228..23c962a3d7075 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -52,6 +52,7 @@ RUN pip install xformers==0.0.23 --no-deps RUN cd /app \ && cd vllm \ && pip install -U -r requirements-rocm.txt \ + && bash patch_rocm.rocm.sh \ && bash patch_xformers.rocm.sh \ && python3 setup.py install \ && cd .. diff --git a/README.md b/README.md index 8ea4d029dc64f..85c5068b69cb9 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,20 @@ +multi-lora rocm development + +Derived from [Yard1's multi-lora branch](https://github.com/Yard1/vllm/tree/multi_lora) + +[Important note] + +Starting from ROCm v5.7, some type conversion functions for bfloat16 are included and implemented in header files. Unfortunately, a few of the host functions are not declared inline or static, so building the project on ROCm directly results in ODR violations when the translation units are linked. + +A way to circumvent this is to manually add the `inline` or `static` keyword to the affected functions. In the `rocm/pytorch` container that `Dockerfile.rocm` builds from, this means adding the keyword `inline` to `/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h:96` so that the line becomes + +```cpp +L96: #define __HOST_DEVICE__ __host__ __device__ inline +``` + +This is far from a pretty solution, though. Even though it appears that [ROCm is fixing this](https://github.com/ROCm/clr/commit/86bd518981b364c138f9901b28a529899d8654f3), the fix does not seem to be included in ROCm 6.0.0. Workarounds like this may need to stay around until better solutions come out. + +
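To make the ODR issue above concrete, here is a minimal, self-contained C++ sketch. It is not part of the patch; the file name `odr_sketch.cpp` and the function `bf16_bits_to_float` are invented for illustration. It mimics the role `amd_hip_bf16.h` plays: a bfloat16-to-float conversion defined in a header must be `inline` (or `static`) so that every translation unit including the header can carry its own copy without producing duplicate external definitions at link time.

```cpp
// odr_sketch.cpp -- illustrative only; names are made up, not from the ROCm headers.
#include <cstdint>
#include <cstring>
#include <iostream>

// Imagine this function lives in a shared header, the way the bfloat16
// conversion helpers live in amd_hip_bf16.h. Marking it `inline` lets each
// translation unit that includes the header keep its own copy without
// violating the one-definition rule; without `inline` (or `static`), linking
// two object files that both include the header fails with a
// "multiple definition" error.
inline float bf16_bits_to_float(std::uint16_t bits) {
  // A bfloat16 value is the upper 16 bits of the corresponding float.
  std::uint32_t widened = static_cast<std::uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &widened, sizeof(out));
  return out;
}

int main() {
  std::cout << bf16_bits_to_float(0x3F80) << "\n";  // prints 1
  std::cout << bf16_bits_to_float(0x4000) << "\n";  // prints 2
  return 0;
}
```

Because line 96 of `amd_hip_bf16.h` defines the `__HOST_DEVICE__` macro used on the conversion functions, adding `inline` there applies this treatment to every function declared through that macro, which is presumably why the one-line patch is sufficient.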

diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e33d5fb2dc247..d75d690cc66d4 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -65,7 +65,9 @@ def run_to_completion(profile_dir: Optional[str] = None): if args.profile: profile_dir = args.profile_result_dir if not profile_dir: - profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" + profile_dir = Path( + "." + ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=args.profile_result_dir) return @@ -123,9 +125,7 @@ def run_to_completion(profile_dir: Optional[str] = None): '--profile-result-dir', type=str, default=None, - help=( - 'path to save the pytorch profiler output. Can be visualized ' - 'with ui.perfetto.dev or Tensorboard.' - )) + help=('path to save the pytorch profiler output. Can be visualized ' + 'with ui.perfetto.dev or Tensorboard.')) args = parser.parse_args() main(args) diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index aa58dd73c148a..1eef4c34607f0 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -18,6 +18,12 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE new file mode 100644 index 0000000000000..a46e2cdcadf7d --- /dev/null +++ b/csrc/punica/LICENSE @@ -0,0 +1,217 @@ +Contains code from https://github.com/punica-ai/punica + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ + +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. 
+ + +Apache-2.0 +* third_party/nvbench (with LLVM exception) +* third_party/flashinfer + +BSD-3-Clause: +* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu new file mode 100644 index 0000000000000..2502a67e3c813 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -0,0 +1,21 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h new file mode 100644 index 0000000000000..664dddc680ab6 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_config.h @@ -0,0 +1,63 @@ +#pragma once + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale); + +// clang-format off + +#define FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 256) \ + f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 1024) \ + f(in_T, out_T, W_T, narrow, 1280) \ + f(in_T, out_T, W_T, narrow, 1728) \ + f(in_T, out_T, W_T, narrow, 1792) \ + f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2560) \ + f(in_T, out_T, W_T, narrow, 2752) \ + f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3456) \ + f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 5120) \ + f(in_T, out_T, W_T, narrow, 5504) \ + f(in_T, out_T, W_T, narrow, 6912) \ + f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 9216) \ + f(in_T, out_T, W_T, narrow, 10240) \ + f(in_T, out_T, W_T, narrow, 11008) \ + f(in_T, out_T, W_T, narrow, 12288) \ + f(in_T, out_T, W_T, narrow, 13824) \ + f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 16384) \ + f(in_T, out_T, W_T, narrow, 20480) \ + f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 32000) \ + f(in_T, out_T, W_T, narrow, 32256) \ + f(in_T, out_T, W_T, narrow, 32512) \ + f(in_T, out_T, W_T, narrow, 32768) \ + f(in_T, out_T, W_T, narrow, 33024) \ + f(in_T, out_T, W_T, narrow, 36864) \ + f(in_T, out_T, W_T, narrow, 
49152) \ +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + +#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 128) \ + FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, narrow) \ +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + +// Keep this in sync with vllm/config::LoRAConfig +#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) \ + FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, 128) + +// clang-format on diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh new file mode 100644 index 0000000000000..2e72394647c6a --- /dev/null +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -0,0 +1,372 @@ +#pragma once + +#include +#ifndef USE_ROCM +#include +#else +#include +#endif +#ifndef USE_ROCM +#include +#endif +#include +#include +#include + +#include "vec_dtypes.cuh" + +namespace cg = cooperative_groups; + +#ifdef USE_ROCM + +template +__host__ __device__ +inline void* memcpy_blocking(void *dst, const void *src) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; +#pragma unroll + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; +} +#endif + +// nthrs = (32, 4) +template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t j = blockIdx.x; + constexpr size_t num_pipeline_stages = 2; + constexpr size_t tile_size = tx * ty * vec_size; + __shared__ W_T W_shared[num_pipeline_stages * tile_size]; + __shared__ in_T X_shared[num_pipeline_stages * tile_size]; + __shared__ float y_warpwise[ty]; + + size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; +#ifndef USE_ROCM + auto pipe = cuda::make_pipeline(); + + // pipeline load W/X and compute WX; + pipe.producer_acquire(); + cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + pipe.producer_commit(); +#else + memcpy_blocking(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size); + memcpy_blocking(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size); +#endif + size_t copy_idx, compute_idx; + float y = 0.f; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; + ++tile_idx) { + copy_idx = tile_idx % num_pipeline_stages; + // pipeline stage: async copy W fragment +#ifndef USE_ROCM + pipe.producer_acquire(); + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + 
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + } + pipe.producer_commit(); +#else + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + memcpy_blocking(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + memcpy_blocking(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + } +#endif + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // pipeline stage: compute WX +#ifndef USE_ROCM + pipe.consumer_wait(); +#endif + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM + sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); + } + y_warpwise[threadIdx.y] = sum; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); +#ifndef USE_ROCM + pipe.consumer_release(); +#endif + } + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // final pipeline stage +#ifndef USE_ROCM + pipe.consumer_wait(); +#endif + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM + sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); + } + y_warpwise[threadIdx.y] = + ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) + ? 
sum + : 0.f; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); +#ifndef USE_ROCM + pipe.consumer_release(); +#endif + + // write Y; + if (block.thread_rank() == 0) { +#ifndef USE_ROCM + Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + j; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(y)); +#endif + } +} + +// nthrs = (2, 16, 4) +template +__global__ void +bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t tile_idx = blockIdx.x; + + // load X; + vec_t x_vec; + x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); + + // load W; + vec_t w_vec; + w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + + block.thread_rank() * vec_size); + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM + sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif + } + + cg::thread_block_tile g = cg::tiled_partition(block); +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += g.shfl_down(sum, offset); + } + sum = g.shfl(sum, 0); + + if (threadIdx.x == 0) { +#ifndef USE_ROCM + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y] + += static_cast(sum); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); +#endif + } +} + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + constexpr size_t vec_size = 8; + constexpr int tz = 4; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if constexpr (feat_in < feat_out) { + static_assert(feat_in % vec_size == 0); + constexpr int tx = feat_in / vec_size; + + static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || + (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || + (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); + + if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { + constexpr int ty = 32 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { + constexpr int ty = 16 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else { + constexpr int ty = 8 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } else { + static_assert(feat_in % (vec_size * 32) == 0 || + feat_in % (vec_size * 16) == 0 || + feat_in % (vec_size * 8) == 0); + + if constexpr 
(feat_in % (vec_size * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { + constexpr int tx = 16; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } +} + +#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ + template void bgmv_kernel( \ + out_T * __restrict__ Y, const in_T *__restrict__ X, \ + const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ + int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ + int64_t num_layers, int64_t layer_idx, float scale); + +#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ + INST_BGMV(narrow, wide, in_T, out_T, W_T) \ + INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh new file mode 100644 index 0000000000000..2738892e6dc4a --- /dev/null +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -0,0 +1,1325 @@ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#ifdef FLASHINFER_USE_FP8 +#include +#endif +#include + +#include + +#include "../type_convert.h" +#include "../../cuda_compat.h" + +#define FLASHINFER_INLINE \ + inline __attribute__((always_inline)) __device__ __host__ + +template +struct vec_t { + FLASHINFER_INLINE float_t &operator[](size_t i); + FLASHINFER_INLINE const float_t &operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t *ptr); + FLASHINFER_INLINE void store(float_t *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src); + template + FLASHINFER_INLINE void cast_load(const T *ptr); + template + FLASHINFER_INLINE void cast_store(T *ptr) const; + FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); +}; + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = tgt_float_t(src[i]); + } +} + +template +FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, + vec_t &dst) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl(const vec_t &src, + tgt_float_t *dst_ptr) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +#ifdef FLASHINFER_USE_FP8 +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> +struct vec_t<__nv_fp8_e4m3, 1> { + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) 
const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( + __nv_fp8_e4m3 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> +struct vec_t<__nv_fp8_e4m3, 2> { + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x2_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x2_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> +struct vec_t<__nv_fp8_e4m3, 4> { + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x4_e4m3 *)ptr); +} + +FLASHINFER_INLINE void 
vec_t<__nv_fp8_e4m3, 4>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x4_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> +struct vec_t<__nv_fp8_e4m3, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { + ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( + __nv_fp8_e4m3 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 16 or more +template +struct vec_t<__nv_fp8_e4m3, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 
*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> +struct vec_t<__nv_fp8_e5m2, 1> { + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( + __nv_fp8_e5m2 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *dst = *src; +} + +// __nv_fp8_e5m2 x 2 +template <> +struct vec_t<__nv_fp8_e5m2, 2> { + __nv_fp8x2_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x2_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x2_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 4 + +template <> +struct vec_t<__nv_fp8_e5m2, 4> { + __nv_fp8x4_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + 
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x4_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x4_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 8 + +template <> +struct vec_t<__nv_fp8_e5m2, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { + ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( + __nv_fp8_e5m2 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 16 or more + +template +struct vec_t<__nv_fp8_e5m2, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e5m2 
*)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; +#endif + +/******************* vec_t *******************/ + +// half x 1 +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *dst = *src; +} + +// half x 2 +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + data = make_half2(val, val); +} + 
+FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((half2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((half2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((half2 *)dst) = *((half2 *)src); +} + +// half x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2 *)(&data.x) = make_half2(val, val); + *(half2 *)(&data.y) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// half x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)data)[i]; + } + FLASHINFER_INLINE void fill(half val) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + *(half2 *)(&(data[i].x)) = make_half2(val, val); + *(half2 *)(&(data[i].y)) = make_half2(val, val); + *(half2 *)(&(data[i].z)) = make_half2(val, val); + *(half2 *)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(half *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// nv_bfloat16 x 1 +template <> +struct vec_t { + nv_bfloat16 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) 
const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *dst = *src; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t { + nv_bfloat162 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((nv_bfloat162 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((nv_bfloat162 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); +} + +// nv_bfloat16 x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// nv_bfloat16 x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); + *(nv_bfloat162 
*)(&(data[i].y)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// float x 1 + +template <> +struct vec_t { + float data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t { + float2 data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { + data = make_float2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { + data = *((float2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { + *((float2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *((float2 *)dst) = *((float2 *)src); +} + +// float x 4 or more +template +struct vec_t { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(data))[i]; + } + FLASHINFER_INLINE void fill(float val) { 
+#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)dst)[i] = ((float4 *)src)[i]; + } + } +}; + +/******************* vec_t type cast *******************/ + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = half(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = + __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = nv_bfloat16(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162 *)(&dst.data))[i] = + __float22bfloat162_rn(((float2 *)(&src.data))[i]); + } + } +} + +#ifdef FLASHINFER_USE_FP8 + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = + __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const 
vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = + __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +#endif // FLASHINFER_USE_FP8 + +#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu new file mode 100644 index 0000000000000..644740d9c49b0 --- /dev/null +++ b/csrc/punica/punica_ops.cu @@ -0,0 +1,550 @@ +#include +#include + +#include + +#include "type_convert.h" +#include "../cuda_compat.h" +#include "bgmv/bgmv_config.h" + +//====== utils ====== + +inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, + const char *a_name, const char *b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", + a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, + ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) \ + TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) \ + TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +//====== bgmv ====== + +template +inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, + const int64_t *lora_indices, + uint16_t in_features, uint16_t out_features, + int64_t y_offset, int64_t full_y_size, + int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + switch (pack_u16(in_features, out_features)) { +#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ + case pack_u16(feat_in, feat_out): \ + bgmv_kernel(Y, X, W, lora_indices, y_offset, \ + full_y_size, batch_size, num_layers, \ + layer_idx, scale); \ + break; +#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) + + FOR_BGMV_WIDE_NARROW(CASE, _, _, _) +#undef CASE +#undef CASE_ONESIDE + default: + return false; + } + + return true; +} + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t h_in = x.size(1); + int64_t h_out = y.size(1); + int64_t num_layers = w.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + 
case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case 
at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t num_layers = w.size(1); + int64_t full_y_size = y.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + 
indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", 
h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h new file mode 100644 index 0000000000000..937e2d1d25d4a --- /dev/null +++ b/csrc/punica/punica_ops.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale); + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp new file mode 100644 index 0000000000000..4435496619a29 --- /dev/null +++ b/csrc/punica/punica_pybind.cpp @@ -0,0 +1,13 @@ +#include + +#include "punica_ops.h" + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} \ No newline at end of file diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h new file mode 100644 index 0000000000000..dff7ce49283d7 --- /dev/null +++ b/csrc/punica/type_convert.h @@ -0,0 +1,82 @@ +#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ +#define CSRC__PUNICA__TYPE_CONVERT_H__ + +#ifndef USE_ROCM + +#include +#include + +#else + +#include +#include + +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + +typedef __half nv_half; +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { + return __hip_bfloat162{val, val}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { + return __hip_bfloat162{vall, valr}; +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T_dst convert_type(T_src val) { + return static_cast(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__half, float>(__half val) { + return __half2float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half convert_type(float val) { + return __float2half(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 convert_type(float val) { + return __float2bfloat16(val); +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T vllm_add(T a, T b) { + return a + b; +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half vllm_add<__half>(__half a, __half b) { + return __hadd(a, b); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { + return __hadd(a, b); +} + +#undef __TYPE_CONVERT__HOST_DEVICE__ + +#endif // USE_ROCM + +#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py new file mode 100644 index 0000000000000..8fdd243af69ff --- /dev/null +++ b/examples/multilora_inference.py @@ -0,0 +1,117 @@ +""" +This example shows how to use the multi-LoRA functionality for offline inference. + +Requires HuggingFace credentials for access to Llama2. 
+""" + +from typing import Optional, List, Tuple + +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput +from vllm.lora.request import LoRARequest + + +def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + Since we also set `max_loras=1`, the expectation is that the requests + with the second LoRA adapter will be ran after all requests with the + first adapter have finished. + """ + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + presence_penalty=0.2, + max_tokens=128), None), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + print(request_output) + + +def initialize_engine() -> LLMEngine: + """Initialize the LLMEngine.""" + # max_loras: controls the number of LoRAs that can be used in the same + # batch. Larger numbers will cause higher memory usage, as each LoRA + # slot requires its own preallocated tensor. + # max_lora_rank: controls the maximum supported rank of all LoRAs. 
Larger + # numbers will cause higher memory usage. If you know that all LoRAs will + # use the same rank, it is recommended to set this as low as possible. + # max_cpu_loras: controls the size of the CPU LoRA cache. + engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=256) + return LLMEngine.from_engine_args(engine_args) + + +def main(): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine() + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + + +if __name__ == '__main__': + main() diff --git a/patch_rocm.rocm.sh b/patch_rocm.rocm.sh new file mode 100644 index 0000000000000..fc7570e6086c7 --- /dev/null +++ b/patch_rocm.rocm.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -e + +if [ -z "$ROCM_PATH" ]; then + echo "Could not determine ROCm installation path by ROCM_PATH. Abort HIP patching" + exit 1 +fi + +export __HIP_FILE_TO_PATCH="$ROCM_PATH/include/hip/amd_detail/amd_hip_bf16.h" +export __HIP_PATCH_FILE="./rocm_patch/rocm__amd_bf16.patch" + +if [ ! -f "$__HIP_FILE_TO_PATCH" ]; then + echo "Could not find the file to be patched in $__HIP_FILE_TO_PATCH. Abort HIP patching" + exit 2 +fi + +echo "File to be patched: $__HIP_FILE_TO_PATCH" + +if ! patch -R -p0 -s -f --dry-run $__HIP_FILE_TO_PATCH $__HIP_PATCH_FILE; then + echo "Applying patch to ${__HIP_FILE_TO_PATCH}" + patch -p0 $__HIP_FILE_TO_PATCH $__HIP_PATCH_FILE + echo "Successfully patched ${__HIP_FILE_TO_PATCH}" +else + echo "${__HIP_FILE_TO_PATCH} has been patched before" +fi + +exit 0 \ No newline at end of file diff --git a/rocm_patch/rocm__amd_bf16.patch b/rocm_patch/rocm__amd_bf16.patch new file mode 100644 index 0000000000000..dbdb8002169fb --- /dev/null +++ b/rocm_patch/rocm__amd_bf16.patch @@ -0,0 +1,975 @@ +--- amd_hip_bf16_ori.h 2024-01-05 06:45:45.451392253 +0000 ++++ amd_hip_bf16.h 2024-01-05 06:44:22.164682921 +0000 +@@ -85,17 +85,31 @@ + #ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ + #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ + ++#if !defined(__HIPCC_RTC__) ++#include ++#endif // !defined(__HIPCC_RTC__) ++ + #include "amd_hip_vector_types.h" // float2 etc + #include "device_library_decls.h" // ocml conversion functions + #include "math_fwd.h" // ocml device functions + + #if defined(__HIPCC_RTC__) +-#define __HOST_DEVICE__ __device__ ++#define __HOST_DEVICE__ __device__ static + #else ++#include + #include +-#define __HOST_DEVICE__ __host__ __device__ ++#include ++#define __HOST_DEVICE__ __host__ __device__ static inline + #endif + ++#define HIPRT_ONE_BF16 __float2bfloat16(1.0f) ++#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f) ++#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) ++#define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU) ++#define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U) ++#define HIPRT_NAN_BF16 __ushort_as_bfloat16((unsigned short)0x7FFFU) ++#define HIPRT_NEG_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x8000U) ++ + // Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on + // different machines. These naive checks should prevent some undefined behavior on systems which + // have different sizes for basic types. 
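
The hunk above redefines `__HOST_DEVICE__` to carry `static` (and, on the non-RTC path, `static inline`), which changes the linkage of every bf16 helper the header defines. As a minimal, hypothetical sketch of the underlying C++ rule — the file and function names below are invented and are not part of the patch or of vLLM — a function defined in a header without `inline` or `static` is emitted by every translation unit that includes it, so linking two such objects fails with a duplicate-symbol error, while the `static inline` form links cleanly:

```cpp
// bf16_util.h -- hypothetical header, only to illustrate the linkage issue the
// patched __HOST_DEVICE__ macro works around; not part of vLLM or the HIP headers.
#pragma once
#include <cstdint>
#include <cstring>

// Would break: a plain definition in a header is emitted by every .cpp/.cu file
// that includes it, so linking two translation units reports a duplicate symbol.
//   float bf16_to_float(uint16_t bits) { ... }

// Links cleanly: `static inline`, which is what the patched macro prepends to
// each HIP bf16 helper.
static inline float bf16_to_float(uint16_t bits) {
  // bfloat16 is the upper 16 bits of an IEEE-754 float.
  uint32_t w = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &w, sizeof(f));
  return f;
}
```

With the macro patched this way, several `.cu` files in the extension can include `amd_hip_bf16.h` and still link, which is the situation the following hunks rely on.
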
+@@ -185,7 +199,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Moves bfloat16 value to bfloat162 + */ +-__device__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { ++__HOST_DEVICE__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { + return __hip_bfloat162{a, a}; + } + +@@ -193,13 +207,13 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer + */ +-__device__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; } ++__HOST_DEVICE__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer + */ +-__device__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; } ++__HOST_DEVICE__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV +@@ -221,7 +235,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Combine two __hip_bfloat16 to __hip_bfloat162 + */ +-__device__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __hip_bfloat162{a, b}; + } + +@@ -229,13 +243,13 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns high 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; } ++__HOST_DEVICE__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns high 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { + return __hip_bfloat162{a.y, a.y}; + } + +@@ -249,7 +263,8 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Extracts high 16 bits from each and combines them + */ +-__device__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, ++ const __hip_bfloat162 b) { + return __hip_bfloat162{a.y, b.y}; + } + +@@ -257,13 +272,13 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns low 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } ++__HOST_DEVICE__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns low 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { + return __hip_bfloat162{a.x, a.x}; + } + +@@ -277,7 +292,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Swaps both halves + */ +-__device__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { + return __hip_bfloat162{a.y, a.x}; + } + +@@ -285,7 +300,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Extracts low 16 bits from each and combines them + */ +-__device__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, 
const __hip_bfloat162 b) { + return __hip_bfloat162{a.x, b.x}; + } + +@@ -293,7 +308,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets short int into a bfloat16 + */ +-__device__ __hip_bfloat16 __short_as_bfloat16(const short int a) { ++__HOST_DEVICE__ __hip_bfloat16 __short_as_bfloat16(const short int a) { + return __hip_bfloat16{(unsigned short)a}; + } + +@@ -301,7 +316,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets unsigned short int into a bfloat16 + */ +-__device__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { ++__HOST_DEVICE__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { + return __hip_bfloat16{a}; + } + +@@ -310,7 +325,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Adds two bfloat16 values + */ +-__device__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); + } + +@@ -318,7 +333,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Subtracts two bfloat16 values + */ +-__device__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b)); + } + +@@ -326,7 +341,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Divides two bfloat16 values + */ +-__device__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b)); + } + +@@ -344,7 +359,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Multiplies two bfloat16 values + */ +-__device__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b)); + } + +@@ -352,7 +367,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Negate a bfloat16 value + */ +-__device__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { ++__HOST_DEVICE__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { + auto ret = a; + ret.data ^= 0x8000; + return ret; +@@ -362,7 +377,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Returns absolute of a bfloat16 + */ +-__device__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { ++__HOST_DEVICE__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { + auto ret = a; + ret.data &= 0x7FFF; + return ret; +@@ -372,7 +387,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Divides bfloat162 values + */ +-__device__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)), + __float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y))}; + } +@@ -381,7 +396,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Returns absolute of a bfloat162 + */ +-__device__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { + return __hip_bfloat162{__habs(a.x), __habs(a.y)}; + } + +@@ -389,7 +404,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Adds two 
bfloat162 values + */ +-__device__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__hadd(a.x, b.x), __hadd(a.y, b.y)}; + } + +@@ -406,7 +421,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Multiplies two bfloat162 values + */ +-__device__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__hmul(a.x, b.x), __hmul(a.y, b.y)}; + } + +@@ -414,7 +429,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Converts a bfloat162 into negative + */ +-__device__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { + return __hip_bfloat162{__hneg(a.x), __hneg(a.y)}; + } + +@@ -422,15 +437,251 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Subtracts two bfloat162 values + */ +-__device__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__hsub(a.x, b.x), __hsub(a.y, b.y)}; + } + + /** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to multiply two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hmul(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to multiply-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hmul(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to unary+ on a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to add two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hadd(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to negate a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to subtract two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hsub(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to post increment a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) { ++ auto ret = l; ++ l = __hadd(l, HIPRT_ONE_BF16); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to pre increment a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) { ++ l = __hadd(l, HIPRT_ONE_BF16); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to post decrement a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) { ++ auto ret = l; ++ l = __hsub(l, HIPRT_ONE_BF16); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to pre decrement a __hip_bfloat16 
number ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) { ++ l = __hsub(l, HIPRT_ONE_BF16); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to add-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hadd(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to subtract-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hsub(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to divide two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hdiv(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to divide-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hdiv(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to multiply two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hmul2(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to multiply-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator*=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __hmul2(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to unary+ on a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to add two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hadd2(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to negate a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l) { return __hneg2(l); } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to subtract two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hsub2(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to post increment a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) { ++ auto ret = l; ++ l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to pre increment a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) { ++ l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to post decrement a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) { ++ auto ret = l; ++ l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to pre decrement a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& 
l) { ++ l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to add-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __hadd2(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to subtract-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __hsub2(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to divide two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __h2div(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to divide-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __h2div(l, r); ++ return l; ++} ++ ++/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values + */ +-__device__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) == __bfloat162float(b); + } + +@@ -438,7 +689,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered equal + */ +-__device__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) < __bfloat162float(b)) && + !(__bfloat162float(a) > __bfloat162float(b)); + } +@@ -447,7 +698,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - greater than + */ +-__device__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) > __bfloat162float(b); + } + +@@ -455,7 +706,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered greater than + */ +-__device__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) <= __bfloat162float(b)); + } + +@@ -463,7 +714,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - greater than equal + */ +-__device__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) >= __bfloat162float(b); + } + +@@ -471,7 +722,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered greater than equal + */ +-__device__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) < __bfloat162float(b)); + } + +@@ -479,7 +730,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - not equal + */ +-__device__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) != __bfloat162float(b); + } + +@@ -487,7 +738,7 @@ + * \ingroup 
HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered not equal + */ +-__device__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) == __bfloat162float(b)); + } + +@@ -495,23 +746,31 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - return max + */ +-__device__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++#if __HIP_DEVICE_COMPILE__ + return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b))); ++#else ++ return __float2bfloat16(std::max(__bfloat162float(a), __bfloat162float(b))); ++#endif + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - return min + */ +-__device__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++#if __HIP_DEVICE_COMPILE__ + return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b))); ++#else ++ return __float2bfloat16(std::min(__bfloat162float(a), __bfloat162float(b))); ++#endif + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - less than operator + */ +-__device__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) < __bfloat162float(b); + } + +@@ -519,15 +778,15 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered less than + */ +-__device__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) >= __bfloat162float(b)); + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP +- * \brief Compare two bfloat162 values - less than ++ * \brief Compare two bfloat162 values - less than equal + */ +-__device__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) <= __bfloat162float(b); + } + +@@ -535,7 +794,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered less than equal + */ +-__device__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) > __bfloat162float(b)); + } + +@@ -543,19 +802,33 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Checks if number is inf + */ +-__device__ int __hisinf(const __hip_bfloat16 a) { return __ocml_isinf_f32(__bfloat162float(a)); } ++__HOST_DEVICE__ int __hisinf(const __hip_bfloat16 a) { ++ unsigned short sign = a.data & 0x8000U; ++#if __HIP_DEVICE_COMPILE__ ++ int res = __ocml_isinf_f32(__bfloat162float(a)); ++#else ++ int res = std::isinf(__bfloat162float(a)) ? 1 : 0; ++#endif ++ return (res == 0) ? res : ((sign != 0U) ? 
-res : res); ++} + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Checks if number is nan + */ +-__device__ bool __hisnan(const __hip_bfloat16 a) { return __ocml_isnan_f32(__bfloat162float(a)); } ++__HOST_DEVICE__ bool __hisnan(const __hip_bfloat16 a) { ++#if __HIP_DEVICE_COMPILE__ ++ return __ocml_isnan_f32(__bfloat162float(a)); ++#else ++ return std::isnan(__bfloat162float(a)); ++#endif ++} + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks if two numbers are equal + */ +-__device__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __heq(a.x, b.x) && __heq(a.y, b.y); + } + +@@ -563,7 +836,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks if two numbers are equal - unordered + */ +-__device__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hequ(a.x, b.x) && __hequ(a.y, b.y); + } + +@@ -571,7 +844,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b + */ +-__device__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hge(a.x, b.x) && __hge(a.y, b.y); + } + +@@ -579,7 +852,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b - unordered + */ +-__device__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hgeu(a.x, b.x) && __hgeu(a.y, b.y); + } + +@@ -587,7 +860,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b + */ +-__device__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hgt(a.x, b.x) && __hgt(a.y, b.y); + } + +@@ -595,7 +868,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b - unordered + */ +-__device__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hgtu(a.x, b.x) && __hgtu(a.y, b.y); + } + +@@ -603,7 +876,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b + */ +-__device__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hle(a.x, b.x) && __hle(a.y, b.y); + } + +@@ -611,7 +884,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b - unordered + */ +-__device__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hleu(a.x, b.x) && __hleu(a.y, b.y); + } + +@@ -619,7 +892,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b + */ +-__device__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hlt(a.x, b.x) && __hlt(a.y, b.y); + } + +@@ -627,7 +900,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b - unordered + */ +-__device__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hltu(a.x, b.x) 
&& __hltu(a.y, b.y); + } + +@@ -635,7 +908,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b + */ +-__device__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hne(a.x, b.x) && __hne(a.y, b.y); + } + +@@ -643,7 +916,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b + */ +-__device__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hneu(a.x, b.x) && __hneu(a.y, b.y); + } + +@@ -651,84 +924,175 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b, returns 1.0 if equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__heq(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__heq(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__heq(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__heq(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hge(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hge(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hge(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hge(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hgt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hgt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hgt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hgt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a is NaN, returns 1.0 if NaN, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { +- return __hip_bfloat162{ +- {__ocml_isnan_f32(__bfloat162float(a.x)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__ocml_isnan_f32(__bfloat162float(a.y)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { ++ return __hip_bfloat162{{__hisnan(a.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hisnan(a.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hle(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hle(a.y, b.y) ? 
__float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hle(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hle(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hlt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hlt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hlt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hlt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Returns max of two elements + */ +-__device__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{ +- __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.x), __bfloat162float(b.x))), +- __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.y), __bfloat162float(b.y)))}; ++__HOST_DEVICE__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{__hmax(a.x, b.x), __hmax(a.y, b.y)}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Returns min of two elements + */ +-__device__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{ +- __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.x), __bfloat162float(b.x))), +- __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.y), __bfloat162float(b.y)))}; ++__HOST_DEVICE__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{__hmin(a.x, b.x), __hmin(a.y, b.y)}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks for not equal to + */ +-__device__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hne(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hne(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hne(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hne(a.y, b.y) ? 
HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __heq(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a not equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hne(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a less than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hlt(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hle(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a greater than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hgt(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hge(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __heq(l.x, r.x) && __heq(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a not equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hne(l.x, r.x) || __hne(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a less than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hlt(l.x, r.x) && __hlt(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hle(l.x, r.x) && __hle(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a greater than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hgt(l.x, r.x) && __hgt(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hge(l.x, r.x) && __hge(l.y, r.y); + } + + /** +@@ -970,5 +1334,4 @@ + __device__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) { + return __hip_bfloat162{htrunc(h.x), htrunc(h.y)}; + } +- + #endif diff --git a/setup.py b/setup.py index 811d494e7a01f..22a6a251e5c23 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,16 @@ +import contextlib import io import os import re import subprocess -from typing import List, Set import 
warnings +from pathlib import Path +from typing import List, Set from packaging.version import parse, Version import setuptools import torch +import torch.utils.cpp_extension as torch_cpp_ext from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) @@ -47,7 +50,7 @@ def _is_cuda() -> bool: ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] - +NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() def get_amdgpu_offload_arch(): command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" @@ -87,6 +90,11 @@ def get_hipcc_rocm_version(): return None +def glob(pattern: str): + root = Path(__name__).parent + return [str(p) for p in root.glob(pattern)] + + def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -188,6 +196,8 @@ def get_torch_arch_list() -> Set[str]: raise RuntimeError( "CUDA 11.8 or higher is required for compute capability 9.0.") + NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() + # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: num = capability[0] + capability[2] @@ -196,12 +206,30 @@ def get_torch_arch_list() -> Set[str]: NVCC_FLAGS += [ "-gencode", f"arch=compute_{num},code=compute_{num}" ] + if int(capability[0]) >= 8: + NVCC_FLAGS_PUNICA += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS_PUNICA += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] # Use NVCC threads to parallelize the build. if nvcc_cuda_version >= Version("11.2"): nvcc_threads = int(os.getenv("NVCC_THREADS", 8)) num_threads = min(os.cpu_count(), nvcc_threads) NVCC_FLAGS += ["--threads", str(num_threads)] + + # changes for punica kernels + NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS + REMOVE_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + for flag in REMOVE_NVCC_FLAGS: + with contextlib.suppress(ValueError): + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) elif _is_hip(): amd_arch = get_amdgpu_offload_arch() @@ -227,6 +255,18 @@ def get_torch_arch_list() -> Set[str]: if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") +install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1"))) + +if _is_cuda(): + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 8: + install_punica = False + break +elif _is_hip(): + pass + vllm_extension = CUDAExtension( name="vllm._C", sources=vllm_extension_sources, @@ -237,6 +277,19 @@ def get_torch_arch_list() -> Set[str]: ) ext_modules.append(vllm_extension) +if install_punica: + ext_modules.append( + CUDAExtension( + name="vllm._punica_C", + sources=["csrc/punica/punica_ops.cu", + "csrc/punica/punica_pybind.cpp"] + + glob("csrc/punica/bgmv/*.cu"), + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS_PUNICA, + }, + )) + def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) diff --git a/tests/lora/__init__.py b/tests/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 0000000000000..d7d46563fe5dc --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,140 @@ +import gc +import tempfile +from collections import OrderedDict +from 
unittest.mock import patch, MagicMock + +import pytest +import ray +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download + +import vllm +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel, initialize_model_parallel) + + +def cleanup(): + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() + + +@pytest.fixture(autouse=True) +def cleanup_fixture(): + yield + cleanup() + + +@pytest.fixture +def dist_init(): + if not torch.distributed.is_initialized(): + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(1, 1) + yield + cleanup() + + +@pytest.fixture +def dist_init_torch_only(): + if torch.distributed.is_initialized(): + return + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + + +@pytest.fixture +def dummy_model() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture +def dummy_model_gate_up() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture(scope="session") +def sql_lora_files(): + return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +@pytest.fixture +def llama_2_7b_engine_extra_embeddings() -> nn.Module: + cleanup() + get_model_old = get_model + + def get_model_patched(model_config, lora_config=None): + return get_model_old(model_config, + LoRAConfig(max_loras=4, max_lora_rank=8)) + + with patch("vllm.worker.model_runner.get_model", get_model_patched): + engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) + yield engine.llm_engine + del engine + cleanup() + + +@pytest.fixture +def llama_2_7b_model_extra_embeddings( + llama_2_7b_engine_extra_embeddings) -> nn.Module: + yield llama_2_7b_engine_extra_embeddings.workers[0].model_runner.model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 0000000000000..71c671132205a --- /dev/null +++ 
b/tests/lora/test_layers.py @@ -0,0 +1,709 @@ +import pytest +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Dict, Tuple + +import torch +import torch.nn.functional as F + +from vllm.lora.layers import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + VocabParallelEmbeddingWithLoRA, + RowParallelLinearWithLoRA, + SamplerWithLoRA, + LoRAMapping, + BaseLayerWithLoRA, +) +from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.utils import set_random_seed + +from .utils import DummyLoRAManager + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: BaseLayerWithLoRA, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. Only useful when + # repeats > 1. 
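+ # (repeats > 1 corresponds to packed layers such as gate_up_proj or QKV, where several sub-LoRAs are sliced from lora_b and packed into a single layer.)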
+ sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager().init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + """ + + low, high = input_range + + inputs, index_mapping, prompt_mapping = [], [], [] + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device="cuda")) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device="cuda") * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings(dist_init, num_loras) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[512:, :] = 0 + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info) + + lora_result = lora_embedding(torch.cat(inputs)) + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = embedding(input_) + after_a = F.embedding( + input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + 
expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(inputs)) + expected_result = embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.") +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding_data = torch.rand_like(embedding.weight.data) + embedding.weight.data = embedding_data + embedding.weight.data[512:, :] = 0 + expanded_embedding = VocabParallelEmbedding( + 512 + lora_config.lora_extra_vocab_size * max_loras, + 256, + org_num_embeddings=512) + expanded_embedding.weight.data[:512, :] = embedding_data + # We need to deepcopy the embedding as it will be modifed + # in place + lora_embedding = VocabParallelEmbeddingWithLoRA( + deepcopy(expanded_embedding)) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return expanded_embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + expanded_embedding, lora_embedding = create_random_embedding_layer() + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=torch.zeros( + (256, 512 + lora_config.lora_extra_vocab_size)), + generate_embeddings_tensor=256, + ) + + # All embeddings tensors have the same shape. + embeddings_tensors = [ + lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) + ] + embeddings_tensor_len = embeddings_tensors[0].shape[0] + + # Add empty embeddings_tensors for unoccupied lora slots. + for _ in range(max_loras - len(embeddings_tensors)): + embeddings_tensors.append( + torch.zeros(embeddings_tensors[0].shape, device="cuda")) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + # Force some of the inputs to be in the extended embeddings range + # to guarantee that their behavior is tested. 
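+ # (lora_id i, 1-based, owns the extra-vocab ids in [512 + (i - 1) * embeddings_tensor_len, 512 + i * embeddings_tensor_len); the matching original_inputs ids stay in the shared range starting at 512.)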
+ for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + embedding_id = lora_id - 1 + input_[-1] = 512 + (embedding_id * embeddings_tensor_len) + original_input_[-1] = 512 + input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = 512 + embeddings_tensor_len - 1 + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + expanded_embedding.weight[512:512 + + (embeddings_tensor_len * + max_loras)] = torch.cat(embeddings_tensors) + + lora_result = lora_embedding(torch.cat(original_inputs)) + + expected_results = [] + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + lora = lora_dict[lora_id] + result = expanded_embedding(input_) + after_a = F.embedding( + original_input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(original_inputs)) + expected_result = expanded_embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_lm_head_sampler(dist_init, num_loras) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_sampler_layer(): + linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, + 1024, 32000) + linear.weight.data = torch.rand_like(linear.weight.data) + linear.weight.data[:, 32000:] = 0 + sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) + lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, + linear.weight.device) + lora_sampler.create_lora_weights(max_loras, lora_config) + + return linear, sampler, lora_sampler + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, sampler, lora_sampler = create_random_sampler_layer() + + # NOTE: all the generated loras share the same embeddings tensor. 
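+ # (This is why reading embeddings_tensor from a single arbitrary entry of lora_dict below is sufficient.)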
+ lora_dict, _ = populate_loras( + id_to_index, + layer=lora_sampler, + layer_weights=linear.weight, + generate_embeddings_tensor=1024, + ) + embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor + embeddings_tensor_len = embeddings_tensor.shape[0] + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, # * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + input_ = torch.rand(20, 1024, device="cuda") + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 32000, + lora_config.lora_extra_vocab_size, + ) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) + + original_weight = linear.weight.clone() + + linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + embeddings_tensor_len] = embeddings_tensor + + sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = sampler._get_logits(hidden_states=input_, + embedding=linear.weight, + embedding_bias=None) + result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + sampler.org_vocab_size = 32000 + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_sampler.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 32000, + lora_config.lora_extra_vocab_size) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +def test_linear_parallel(dist_init, num_loras, orientation) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_linear_parallel_layer(): + if orientation == "row": + linear = RowParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = RowParallelLinearWithLoRA(linear) + else: + linear = ColumnParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = ColumnParallelLinearWithLoRA(linear) + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_parallel_layer() + + lora_dict, _ = populate_loras( + id_to_index, + 
layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("repeats", [2, 3]) +def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_column_parallel_packed_layer(): + if repeats == 2: + linear = MergedColumnParallelLinear(4096, [4096] * repeats, + bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = MergedColumnParallelLinearWithLoRA(linear) + else: + linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = QKVParallelLinearWithLora(linear) + + @dataclass + class FakeConfig: + hidden_size = 4096 + num_key_value_heads = 32 + num_attention_heads = 32 + + lora_linear.create_lora_weights(max_loras, + lora_config, + model_config=FakeConfig()) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + + linear, lora_linear = create_column_parallel_packed_layer() + + lora_dict, sublora_dict = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + repeats=repeats, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, 
lora_id in zip(inputs, prompt_mapping): + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * ( + i + 1 + )] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py new file mode 100644 index 0000000000000..4760c5cc1e950 --- /dev/null +++ b/tests/lora/test_llama.py @@ -0,0 +1,145 @@ +import pytest +import ray +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + + +def do_sample(llm, lora_path: str, lora_id: int): + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]" + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. 
+ generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_llama_lora(sql_lora_files, tp_size): + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=tp_size, + worker_use_ray=True) + + expected_no_lora_output = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? 
", + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", + ] + expected_lora_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " + ] + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output + + print("removing lora") + + +def test_llama_tensor_parallel_equality(sql_lora_files): + if torch.cuda.device_count() < 4: + pytest.skip(f"Not enough GPUs for tensor parallelism {4}") + + llm_tp1 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + worker_use_ray=True) + output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) + + del llm_tp1 + ray.shutdown() + + llm_tp2 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + worker_use_ray=True) + output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) + + del llm_tp2 + ray.shutdown() + + assert output_tp1 == output_tp2 + + llm_tp4 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + worker_use_ray=True) + output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) + + del llm_tp4 + ray.shutdown() + + assert output_tp1 == output_tp4 + + +def test_llama_lora_warmup(sql_lora_files): + """Test that the LLM initialization works with a warmup LORA path and is more conservative""" + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_lora(): + llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) + num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_lora_warmup + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_no_lora(): + llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) + num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_no_lora_warmup + + num_gpu_blocks_lora_warmup = 
ray.get(get_num_gpu_blocks_lora.remote()) + num_gpu_blocks_no_lora_warmup = ray.get( + get_num_gpu_blocks_no_lora.remote()) + assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( + "The warmup with lora should be more" + " conservative than without lora, therefore the number of memory blocks for the KV cache should be " + "less when using lora than when not using lora") diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py new file mode 100644 index 0000000000000..3415d36b7e341 --- /dev/null +++ b/tests/lora/test_lora.py @@ -0,0 +1,224 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.float16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(8, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="cuda", + dtype=dtype) + lora_b_stack = torch.zeros(8, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="cuda", + dtype=dtype) + for i in range(lora_a_stack.shape[0]): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora( + input, lora_a_stack, lora_b_stack, + torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), + output) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="cuda"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + 
], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, (m // 2, m // 2)) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, (m // 2, m // 2)) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, (qkv[0], qkv[1], qkv[2])) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, 
lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, (qkv[0], qkv[1], qkv[2])) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py new file mode 100644 index 0000000000000..78a4a5bc5ecd2 --- /dev/null +++ b/tests/lora/test_lora_manager.py @@ -0,0 +1,475 @@ +import os +from typing import List + +import pytest +import torch +from safetensors.torch import load_file +from torch import nn + +from vllm.config import LoRAConfig +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA) +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, LoRAMapping) +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, + WorkerLoRAManager) +from vllm.model_executor.layers.linear import RowParallelLinear + + +def test_from_lora_tensors(sql_lora_files): + tensors = load_file( + os.path.join(sql_lora_files, "adapter_model.safetensors")) + new_embeddings = load_file( + os.path.join(sql_lora_files, "new_embeddings.safetensors")) + lora_model = LoRAModel.from_lora_tensors(1, + 8, + 16, + tensors, + "cuda", + embeddings=new_embeddings) + for module_name, lora in lora_model.loras.items(): + assert lora.module_name == module_name + assert lora.rank == 8 + assert lora.lora_alpha == 16 + assert lora.lora_a is not None + assert lora.lora_b is not None + assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] + ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + assert lora.lora_a.shape[1] == 8 + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), None) + if embeddings_module: + assert torch.equal( + lora.embeddings_tensor, + new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( + device=lora.embeddings_tensor.device)) + else: + assert lora.embeddings_tensor is None + + +def create_lora(lora_id: int, model: nn.Module, + sub_modules: List[str]) -> LoRAModel: + loras = {} + for name in sub_modules: + w = model.get_submodule(name).weight + loras[name] = LoRALayerWeights( + name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0]], device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def create_packed_lora( + lora_id: int, + model: nn.Module, + module_name, + replaced_module_names, + empty_replaced_module_name=None, +) -> LoRAModel: + w = model.get_submodule(module_name).weight + loras = {} + for replaced_module_name in replaced_module_names: + if replaced_module_name == empty_replaced_module_name: + continue + loras[replaced_module_name] = LoRALayerWeights( + replaced_module_name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0] // len(replaced_module_names)], + device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def test_replace_submodules(dist_init, dummy_model): + model = dummy_model + manager = LoRAModelManager(model, + 1, + 1, + 1, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=8, + max_loras=8), + lora_target_modules=["dense1", "layer1.dense2"]) + model = manager.model + + assert isinstance(model.get_submodule("dense1"), + ColumnParallelLinearWithLoRA) + assert isinstance(model.get_submodule("layer1.dense1"), + ColumnParallelLinearWithLoRA) + assert isinstance(model.get_submodule("dense2"), RowParallelLinear) + assert 
isinstance(model.get_submodule("layer1.dense2"), + RowParallelLinearWithLoRA) + + +def test_lora_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_index_to_id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_index_to_id[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + with pytest.raises(ValueError): + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_index_to_id[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] is None + assert manager.add_lora(model_lora2) + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] is None + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 + + +def test_lora_lru_cache_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_index_to_id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_index_to_id[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_index_to_id[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_index_to_id[0] == 3 + assert 
manager.lora_index_to_id[1] == 1 + assert manager.add_lora(model_lora2) + assert manager.deactivate_lora(3) + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 1 + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 1 + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 3 + + +def test_lru_lora_model_manager(dist_init, dummy_model): + # This tests just the LRU cache functionality, everything else is + # tested in test_lora_model_manager + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["dense1", "dense2", "lm_head"]) + + assert all(x is None for x in manager.lora_index_to_id) + + # Add up to capacity + assert manager.add_lora(model_lora1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(1) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {1, 2} + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + + # Add over capacity + assert manager.add_lora(model_lora3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(3) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 + + # Add 3 again to move it to the top and then add 2 + # should return false since it's in already + assert not manager.add_lora(model_lora3) + assert not manager.activate_lora(3) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {3, 2} + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 + + # Remove manually + assert manager.remove_lora(3) + assert not manager.remove_lora(3) + + assert set(manager.list_loras()) == {2} + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 2 + + assert manager.add_lora(model_lora3) + assert manager.activate_lora(3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == {4} + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_index_to_id) + + assert not manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_index_to_id) + + +def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = LRUCacheWorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.set_active_loras([ + 
LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 3, 4} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 6, 7, 8} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.set_active_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + # Should remove every LoRA not specified in the request. 
+ lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = WorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 3, 4} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 5} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None + assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None + + worker_lora_manager.set_active_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {6, 7, 8} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.set_active_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_packed_loras(dist_init, dummy_model_gate_up): + model = dummy_model_gate_up + model_lora = create_packed_lora( + 1, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"]) + model_lora1 = create_packed_lora( + 2, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"], + empty_replaced_module_name="gate_proj", + ) + + manager = LoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["gate_up_proj"]) + model = manager.model + + assert isinstance(model.get_submodule("gate_up_proj"), + MergedColumnParallelLinearWithLoRA) + assert manager.add_lora(model_lora) + assert manager.add_lora(model_lora1) + + packed_lora = model_lora.get_lora("gate_up_proj") + assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) + + assert 
torch.allclose(packed_lora.lora_a[0], + model_lora.get_lora("gate_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[0], + model_lora.get_lora("gate_proj").lora_b) + assert torch.allclose(packed_lora.lora_a[1], + model_lora.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[1], + model_lora.get_lora("up_proj").lora_b) + + packed_lora1 = model_lora1.get_lora("gate_up_proj") + assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) + + assert packed_lora1.lora_a[0] is None + assert packed_lora1.lora_b[0] is None + assert torch.allclose(packed_lora1.lora_a[1], + model_lora1.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora1.lora_b[1], + model_lora1.get_lora("up_proj").lora_b) diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py new file mode 100644 index 0000000000000..f603b06cdb565 --- /dev/null +++ b/tests/lora/test_punica.py @@ -0,0 +1,175 @@ +# Based on code from https://github.com/punica-ai/punica + +import pytest +import torch + +import vllm.lora.punica as punica + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), + torch.float32: (None, None), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def _lora_ref_impl( + y_final: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, +): + y_stage_1 = torch.empty( + (x.size(0), wa_T_all.size(-2)), + dtype=torch.float32, + device=x.device, + ) + bs = x.shape[0] + s = torch.tensor(scale, dtype=torch.float32, device=x.device) + for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): + xi = x[i].unsqueeze(0).to(torch.float32) + wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + + tmp = xi @ wa + y_stage_1[i] = tmp.squeeze(0) + y_final[i] += (tmp @ wb).squeeze(0) * s + return y_final, y_stage_1 + + +H1 = H2 = [ + 128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120, + 5504, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000, 32256, + 32512, 32768, 33024 +] +SEED = [0xabcdabcd987] + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness(dtype_str, h1, h2, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all = torch.randn(num_loras, + num_layers, + h2, + r, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) + + y_our = y.clone() + punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, + scale) + + assert_close(y_ref, y_our) + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness_slice(dtype_str, h1, h2, seed): + 
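    # Checks punica.add_lora_slice against the dense _lora_ref_impl reference
    # above: h2 is split into three equal slices of size s = h2 // 3, three
    # independent LoRA weight sets are applied at output offsets 0, s and 2 * s
    # (mirroring how packed projections apply per-slice LoRAs), and each output
    # slice is compared against the reference.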
if h2 % 3 != 0 or h2 // 3 not in H1: + pytest.skip("h2 must be divisible by 3 and in supported shapes") + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all_0 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_1 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_2 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all_0 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_1 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_2 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + s = h2 // 3 + + y_ref = y.clone() + _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale) + + y_our = y.clone() + punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale, 0, s) + punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale, s, s) + punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale, s * 2, s) + + assert_close(y_ref[:, :s], y_our[:, :s]) + assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) + assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py new file mode 100644 index 0000000000000..af0fc41f3fa45 --- /dev/null +++ b/tests/lora/test_tokenizer.py @@ -0,0 +1,69 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import MultiLoRATokenizer, get_lora_tokenizer + + +@pytest.mark.asyncio +async def test_transformers_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=None) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +async def test_transformers_tokenizer_lora(sql_lora_files): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", 
+ lora_request=lora_request) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + lora_request) != tokenizer.get_lora_tokenizer(None) + assert tokenizer.get_lora_tokenizer( + lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py new file mode 100644 index 0000000000000..2996322f4aa48 --- /dev/null +++ b/tests/lora/test_utils.py @@ -0,0 +1,172 @@ +from collections import OrderedDict + +from torch import nn + +from vllm.utils import LRUCache +from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule) + + +def test_parse_fine_tuned_lora_name(): + fixture = { + ("base_model.model.lm_head.lora_A.weight", "lm_head", True), + ("base_model.model.lm_head.lora_B.weight", "lm_head", False), + ( + "base_model.model.model.embed_tokens.lora_embedding_A", + "model.embed_tokens", + True, + ), + ( + "base_model.model.model.embed_tokens.lora_embedding_B", + "model.embed_tokens", + False, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "model.layers.9.mlp.down_proj", + True, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "model.layers.9.mlp.down_proj", + False, + ), + } + for name, module_name, is_lora_a in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) + + +def test_replace_submodule(): + model = nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ])) + + sigmoid = nn.Sigmoid() + + replace_submodule(model, "act1", sigmoid) + assert dict(model.named_modules())["act1"] == sigmoid + + dense2 = nn.Linear(1, 5) + replace_submodule(model, "seq1.dense2", dense2) + assert dict(model.named_modules())["seq1.dense2"] == dense2 + + +class TestLRUCache(LRUCache): + + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache.get(2) == 2 + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 
4} + assert cache._remove_counter == 3 + + cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py new file mode 100644 index 0000000000000..126d910f53ab3 --- /dev/null +++ b/tests/lora/test_worker.py @@ -0,0 +1,57 @@ +import os +import random +import tempfile +from unittest.mock import patch + +from vllm.lora.models import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig +from vllm.worker.worker import Worker + + +@patch.dict(os.environ, {"RANK": "0"}) +def test_worker_apply_lora(sql_lora_files): + worker = Worker( + model_config=ModelConfig("meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None), + parallel_config=ParallelConfig(1, 1, False), + scheduler_config=SchedulerConfig(32, 32, 32, 256), + lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, + max_loras=32), + distributed_init_method=f"file://{tempfile.mkstemp()[1]}", + ) + worker.init_model() + worker.load_model() + + worker.model_runner.set_active_loras([], LoRAMapping([], [])) + assert worker.list_loras() == set() + + n_loras = 32 + lora_requests = [ + LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) + ] + + worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], [])) + assert worker.list_loras() == { + lora_request.lora_int_id + for lora_request in lora_requests + } + + for i in range(32): + random.seed(i) + iter_lora_requests = random.choices(lora_requests, + k=random.randint(1, n_loras)) + random.shuffle(iter_lora_requests) + iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] + worker.model_runner.set_active_loras(iter_lora_requests, + LoRAMapping([], [])) + assert worker.list_loras().issuperset( + {lora_request.lora_int_id + for lora_request in iter_lora_requests}) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 0000000000000..280e0f2043e68 --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,88 @@ +from typing import List, Optional + +import torch + +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights + + 
+class DummyLoRAManager: + + def __init__(self): + super().__init__() + self._loras = {} + + def set_module_lora(self, module_name: str, lora: LoRALayerWeights): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + return self._loras.get(module_name, None) + + def init_random_lora(self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([weight.shape[1], rank], + dtype=weight.dtype, + device="cuda"), + lora_b=torch.rand([rank, weight.shape[0]], + dtype=weight.dtype, + device="cuda"), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand(5, + generate_embeddings_tensor, + dtype=weight.dtype, + device="cuda") + self.set_module_lora(module_name, lora) + + return lora + + def init_lora(self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([input_dim, rank], device="cuda"), + lora_b=torch.rand([rank, output_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: List[int], + noop_lora_index: List[int] = None, + rank=8, + ): + base_loras = [] + noop_lora_index = set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index, + ) + base_loras.append(base_lora) + packed_lora = PackedLoRALayerWeights.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora diff --git a/vllm/config.py b/vllm/config.py index ff9a1308a5c88..b279ec548c511 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,5 @@ -from typing import Optional, Union +from typing import Optional, Union, ClassVar +from dataclasses import dataclass import os import torch @@ -403,6 +404,51 @@ def _verify_args(self) -> None: f"({self.max_num_seqs}).") +@dataclass +class LoRAConfig: + max_lora_rank: int + max_loras: int + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[torch.dtype] = None + lora_extra_vocab_size: int = 256 + # This is a constant. 
+ lora_vocab_padding_size: ClassVar[int] = 256 + + def __post_init__(self): + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + possible_max_ranks = (8, 16, 32, 64, 128) + possible_lora_extra_vocab_size = (0, 256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_num_seqs ({self.max_loras})") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + if scheduler_config.max_num_batched_tokens > 65528: + raise ValueError( + "Due to limitations of the custom LoRA CUDA kernel, " + "max_num_batched_tokens must be <= 65528 when " + "LoRA is enabled.") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 398585a88fb52..fc5ee185c4045 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,10 +1,11 @@ import enum import time -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union, Set -from vllm.config import CacheConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory +from vllm.lora.request import LoRARequest from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -47,11 +48,23 @@ def __init__( assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups + self.num_loras = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) + def _sort_by_lora_ids(self) -> bool: + self.scheduled_seq_groups.sort(key=lambda g: ( + g.lora_request.lora_int_id if g.lora_request else 0, g.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return {g.lora_request for g in self.scheduled_seq_groups} + class Scheduler: @@ -59,9 +72,14 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. 
+ self.lora_config = lora_config self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -83,6 +101,10 @@ def __init__( # Sequence groups in the SWAPPED state. self.swapped: List[SequenceGroup] = [] + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) @@ -131,14 +153,16 @@ def _schedule(self) -> SchedulerOutputs: # requests in the generation phase. num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. - while self.waiting: - seq_group = self.waiting[0] - + waiting_indices_to_remove = [] + for i, seq_group in enumerate(self.waiting): waiting_seqs = seq_group.get_seqs( status=SequenceStatus.WAITING) assert len(waiting_seqs) == 1, ( @@ -152,7 +176,7 @@ def _schedule(self) -> SchedulerOutputs: for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) - self.waiting.pop(0) + waiting_indices_to_remove.append(i) continue # If the sequence group cannot be allocated, stop. @@ -166,9 +190,18 @@ def _schedule(self) -> SchedulerOutputs: for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) - self.waiting.pop(0) + waiting_indices_to_remove.append(i) continue + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue + # If the number of batched tokens exceeds the limit, stop. new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) @@ -188,12 +221,17 @@ def _schedule(self) -> SchedulerOutputs: break seq_lens = new_seq_lens - seq_group = self.waiting.pop(0) + waiting_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append(seq_group) + for i in reversed(waiting_indices_to_remove): + self.waiting.pop(i) + if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, @@ -241,9 +279,22 @@ def _schedule(self) -> SchedulerOutputs: if not preempted: num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + swapped_indices_to_remove = [] + + for i, seq_group in enumerate(self.swapped): + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue - while self.swapped: - seq_group = self.swapped[0] # If the sequence group cannot be swapped in, stop. 
if not self.block_manager.can_swap_in(seq_group): break @@ -255,12 +306,17 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - seq_group = self.swapped.pop(0) + swapped_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) self._swap_in(seq_group, blocks_to_swap_in) self._append_slot(seq_group, blocks_to_copy) num_curr_seqs += num_new_seqs self.running.append(seq_group) + for i in reversed(swapped_indices_to_remove): + self.swapped.pop(i) + # Each sequence in the generation phase only takes one token slot. # Therefore, the number of batched tokens is equal to the number of # sequences in the RUNNING state. @@ -301,6 +357,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + lora_request=seq_group.lora_request, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e58069e2c22d..090fa95bcac02 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) @dataclass @@ -35,6 +35,12 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 + enable_lora: bool = False + max_loras: int = 1 + max_lora_rank: int = 16 + lora_extra_vocab_size: int = 256 + lora_dtype = 'auto' + max_cpu_loras: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -202,6 +208,39 @@ def add_cli_args( help='maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='If True, enable handling of LoRA adapters.') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='Max number of LoRAs in a single batch.') + parser.add_argument('--max-lora-rank', + type=int, + default=EngineArgs.max_lora_rank, + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) + parser.add_argument( + '--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' + 'Must be >= than max_num_seqs. 
' + 'Defaults to max_num_seqs.')) return parser @classmethod @@ -214,7 +253,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, + Optional[LoRAConfig]]: model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, @@ -234,7 +274,14 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, self.max_paddings) - return model_config, cache_config, parallel_config, scheduler_config + lora_config = LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None + return model_config, cache_config, parallel_config, scheduler_config, lora_config @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 611da51f61931..dcdbf2142e83d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,6 +4,7 @@ from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, AsyncIterator) +from vllm.lora.request import LoRARequest from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine @@ -196,6 +197,50 @@ async def step_async(self) -> List[RequestOutput]: return self._process_model_outputs(output, scheduler_outputs) + async def encode_request_async( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = await self.tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + async def add_request_async( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + return self.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + async def _run_workers_async( self, method: str, @@ -325,7 +370,7 @@ async def engine_step(self) -> bool: if self.engine_use_ray: await self.engine.add_request.remote(**new_request) else: - self.engine.add_request(**new_request) + await self.engine.add_request_async(**new_request) if finished_requests: await self._engine_abort(finished_requests) @@ -364,6 +409,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = 
prompt @@ -377,7 +423,8 @@ async def add_request( logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " f"sampling params: {sampling_params}, " - f"prompt token ids: {shortened_token_ids}.") + f"prompt token ids: {shortened_token_ids}, " + f"lora_request: {lora_request}.") if not self.is_running: if self.start_engine_loop: @@ -389,12 +436,22 @@ async def add_request( "error that caused the background loop to stop " "(AsyncEngineDeadError).") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.engine.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + stream = self._request_tracker.add_request( request_id, prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + lora_request=lora_request, + ) return stream @@ -403,7 +460,8 @@ async def generate( prompt: Optional[str], sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -418,6 +476,7 @@ async def generate( request_id: The unique id of the request. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -428,11 +487,14 @@ async def generate( arrival_time = time.monotonic() try: - stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + stream = await self.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + ) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 43bf9747ee184..1f3c8a7cd7ee5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,8 +4,9 @@ from functools import partial from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from vllm.lora.request import LoRARequest from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import record_metrics @@ -16,7 +17,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) + MultiLoRATokenizer) from vllm.utils import Counter if ray: @@ -66,6 +67,7 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], distributed_init_method: str, placement_group: Optional["PlacementGroup"], log_stats: bool, @@ -90,17 +92,13 @@ def __init__( self.model_config = model_config self.cache_config = cache_config + self.lora_config = lora_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats self._verify_args() - self.tokenizer = get_tokenizer( - model_config.tokenizer, - 
tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - tokenizer_revision=model_config.tokenizer_revision, - revision=model_config.revision) + self._init_tokenizer() self.seq_counter = Counter() # Create the parallel GPU workers. @@ -117,7 +115,7 @@ def __init__( self._init_cache() # Create the scheduler. - self.scheduler = Scheduler(scheduler_config, cache_config) + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) # Logging. self.last_logging_time = 0.0 @@ -126,6 +124,9 @@ def __init__( # List of (timestamp, num_tokens) self.num_generation_tokens: List[Tuple[float, int]] = [] + def get_tokenizer_for_seq(self, sequence: Sequence): + return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + def _init_workers(self, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -141,6 +142,7 @@ def _init_workers(self, distributed_init_method: str): self.scheduler_config, 0, distributed_init_method, + lora_config=self.lora_config, ) self.workers.append(worker) self._run_workers( @@ -154,6 +156,18 @@ def _init_workers(self, distributed_init_method: str): max_parallel_loading_workers, ) + def _init_tokenizer(self, **tokenizer_init_kwargs): + init_kwargs = dict( + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + init_kwargs.update(tokenizer_init_kwargs) + self.tokenizer: MultiLoRATokenizer = MultiLoRATokenizer( + self.model_config.tokenizer, **init_kwargs) + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers @@ -191,6 +205,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config, None, None, + lora_config=self.lora_config, )) self._run_workers( "init_model", @@ -206,6 +221,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) def _init_cache(self) -> None: """Profiles the memory usage and initializes the KV cache.""" @@ -265,6 +284,20 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": log_stats=not engine_args.disable_log_stats) return engine + def encode_request( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + def add_request( self, request_id: str, @@ -272,6 +305,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: """Add a request to the engine's request pool. @@ -289,20 +323,26 @@ def add_request( arrival_time: The arrival time of the request. 
If None, we use the current monotonic time. """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") if arrival_time is None: arrival_time = time.monotonic() - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(prompt) + prompt_token_ids = self.encode_request( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + lora_request) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + arrival_time, lora_request) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -341,11 +381,13 @@ def _check_beam_search_early_stopping( current_worst_score = (current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + current_worst_seq).eos_token_id)) if early_stopping is False: highest_attainable_score = (best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -359,7 +401,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id, seq_len=max_possible_length)) else: # Otherwise, beam search will prefer shorter sequences. The @@ -368,7 +411,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -459,7 +503,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the finished sequences by their scores. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: @@ -487,7 +531,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the running sequences by their scores. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) # Check if we can stop the beam search. 
@@ -665,7 +709,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( - self.tokenizer, + self.get_tokenizer_for_seq(seq), all_input_ids=seq.get_token_ids(), prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -707,11 +751,28 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): + if ((not sampling_params.ignore_eos) and seq.get_last_token_id() + == self.get_tokenizer_for_seq(seq).eos_token_id): seq.status = SequenceStatus.FINISHED_STOPPED return + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + def _run_workers_in_batch( self, workers, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0700298b03a3d..a335c2e8e2e25 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.lora.request import LoRARequest from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.outputs import RequestOutput @@ -121,6 +122,7 @@ def generate( sampling_params: Optional[SamplingParams] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -135,6 +137,7 @@ def generate( prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. Returns: A list of `RequestOutput` objects containing the generated @@ -161,7 +164,10 @@ def generate( prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request(prompt, sampling_params, token_ids) + self._add_request(prompt, + sampling_params, + token_ids, + lora_request=lora_request) return self._run_engine(use_tqdm) def _add_request( @@ -169,10 +175,14 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + lora_request=lora_request) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
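The entrypoint changes above thread the new `lora_request` argument from `LLM.generate` down through `LLMEngine.add_request`. A minimal usage sketch, assuming `LLM` forwards engine keyword arguments such as `enable_lora` to `EngineArgs`; the model name and adapter path are placeholders:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora / max_loras / max_lora_rank map to the EngineArgs added above.
llm = LLM(model="meta-llama/Llama-2-7b-hf",
          enable_lora=True,
          max_loras=2,
          max_lora_rank=8)

# LoRARequest takes a human-readable name, a positive integer id, and a local
# adapter path; the integer id is what the scheduler and LoRA managers key on.
sql_lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora_adapter")

outputs = llm.generate(
    ["Translate to SQL: list all users created after 2020."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=sql_lora,  # applied to every prompt in this call
)
print(outputs[0].outputs[0].text)
```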
diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000000000..5c26ce37bbf8d --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,968 @@ +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear, + MergedColumnParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim + +if TYPE_CHECKING: + pass + + +def _apply_lora( + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + indices: torch.Tensor, + output: torch.Tensor, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: (num_loras, lora_rank, hidden_dim) + lora_b_stacked: (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + return output.view_as(org_output) + + +def _apply_lora_packed_nslice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of multiple sublayers + (slices) packed together. 
+ + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), where n is number of slices + """ + org_output = output + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + offset_left = 0 + for slice_idx in range(len(output_slices)): + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] + return output.view_as(org_output) + + +@dataclass +class LoRAMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +class BaseLayerWithLoRA(nn.Module): + + def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, + model_config: PretrainedConfig) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + """Sets the mapping indices.""" + ... 
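# Illustrative reference only (the layers below always go through the fused
# punica add_lora / add_lora_slice kernels): the per-token computation that
# _apply_lora above performs, written out with the 3-D shapes from its
# docstring. An index of -1 leaves that row of the output untouched.
def _apply_lora_reference(
        x: torch.Tensor,  # (batch_size, hidden_dim)
        lora_a_stacked: torch.Tensor,  # (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: torch.Tensor,  # (num_loras, output_dim, lora_rank)
        indices: torch.Tensor,  # (batch_size,)
        output: torch.Tensor,  # (batch_size, output_dim)
) -> torch.Tensor:
    for i, lora_idx in enumerate(indices.tolist()):
        if lora_idx < 0:
            continue
        a = lora_a_stacked[lora_idx]  # (lora_rank, hidden_dim)
        b = lora_b_stacked[lora_idx]  # (output_dim, lora_rank)
        output[i] += (x[i] @ a.t()) @ b.t()
    return output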
+ + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + lora_vocab_start_idx = self.base_layer.org_vocab_size + weights_idx = None + if self.base_layer.vocab_end_index > lora_vocab_start_idx: + # We can start adding lora weights + weights_idx = max( + lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) + self.embeddings_slice = (self.base_layer.vocab_start_index - + self.base_layer.org_vocab_size + + weights_idx, + self.base_layer.vocab_end_index - + self.base_layer.org_vocab_size) + self.embeddings_weights = self.base_layer.weight.data[weights_idx:] + self.embeddings_weights.fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.embeddings_indices = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1]].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2] + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.embeddings_indices = embeddings_indices + self.indices_len = indices_len + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], -1) + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + return full_output.view_as(full_output_org) + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.output_dim = self.lora_b_stacked.shape[1] + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + 
self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply_weights(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @property + def linear_weights(self): + return self.base_layer.linear_weights + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + n_slices = 2 + if not (len(self.base_layer.output_sizes) == n_slices + and self.base_layer.output_sizes[0] + == self.base_layer.output_sizes[1]): + raise ValueError( + "LoRAColumnParallelLinear2Slice requires 2 slices with " + "the same size.") + self.tp_size = get_tensor_model_parallel_world_size() + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0] // 2, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + + self.indices: Optional[torch.Tensor] = None + self.output_dim = self.lora_b_stacked[0].shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_b_stacked[1][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[0][:, + start_idx:end_idx], lora_b[1][:, + start_idx:end_idx] + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + 
self.base_layer.linear_weights, x, bias) + _apply_lora_packed_nslice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + (self.output_dim, self.output_dim), + ) + return output + + +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + # q, k, v + self.lora_a_stacked = (torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + self.lora_b_stacked = (torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, + self.kv_proj_shard_size) + self.packed_indices: Optional[torch.Tensor] = None + self.standard_indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[1][index] = 0 + self.lora_a_stacked[2][index] = 0 + self.lora_b_stacked[2][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] 
is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) + else: + if lora_b[0] is not None: + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_b[1] is not None: + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + if lora_b[2] is not None: + self.lora_b_stacked[2][ + index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( + lora_b[2].T, non_blocking=True) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + if lora_a[2] is not None: + self.lora_a_stacked[2][ + index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( + lora_a[2].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_nslice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_slices, + ) + return output + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.base_layer.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.base_layer.weight.shape[1] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x) + _apply_lora( + x, 
+ self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return self.base_layer.weight + + +class SamplerWithLoRA(BaseLayerWithLoRA): + + def __init__( + self, + base_layer: Sampler, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + if 32000 < self.base_layer.vocab_size > 33024: + raise ValueError( + "When using LoRA, vocab size must be 32000 >= vocab_size <= 33024" + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + self.indices = None + self.indices_padded = None + self.indices_len = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], 
:embeddings_tensor. + shape[1], ] = embeddings_tensor + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = sampler_indices + self.indices_padded = sampler_indices_padded + self.indices_len = indices_len + + def _get_logits( + self, + hidden_states: torch.Tensor, + embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, + self.indices_padded[:self.indices_len[2]]).nan_to_num_( + nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + _apply_lora( + hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[1]], + logits, + ) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + +def from_layer( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: + supported_layer_types = { + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLora, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_sampler( + layer: Sampler, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> SamplerWithLoRA: + ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000000000..fbb228c9582d4 --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,160 @@ +from typing import List, Optional + +import torch +from vllm.utils import in_wsl + + +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha 
= lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling = scaling + + def optimize(self) -> "LoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and not in_wsl() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + + +class PackedLoRALayerWeights(LoRALayerWeights): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[int], + lora_a: List[torch.Tensor], + lora_b: List[torch.Tensor], + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + scaling=scaling, + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ + lora_alpha / self.rank for lora_alpha in self.lora_alphas + ] + + @classmethod + def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. 
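
Packing can be exercised in isolation. A hedged sketch, using `create_dummy_lora_weights` above purely to fabricate inputs (real adapters come from `LoRAModel.from_lora_tensors`); a missing sub-LoRA is passed as `None`, exactly as described here:

```python
import torch
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights

# q_proj and v_proj carry rank-8 adapters; k_proj has no adapter.
q = LoRALayerWeights.create_dummy_lora_weights(
    "q_proj", input_dim=4096, output_dim=4096, rank=8,
    dtype=torch.float16, device="cpu")
v = LoRALayerWeights.create_dummy_lora_weights(
    "v_proj", input_dim=4096, output_dim=1024, rank=8,
    dtype=torch.float16, device="cpu")

packed = PackedLoRALayerWeights.pack([q, None, v])
assert packed.is_packed
assert packed.lora_a[1] is None  # the k_proj slot stays empty
```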
+ """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: + continue + self.lora_b[i] *= self.scaling[i] + self.scaling[i] = 1 + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000000000..6c78c4a2c7771 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,654 @@ +import copy +import json +import logging +import math +import os +import re +from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type, + Union) + +import safetensors.torch +import torch +from torch import nn + +from vllm.config import LoRAConfig +from vllm.utils import LRUCache, in_wsl + +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule + +logger = logging.getLogger(__name__) + +# TODO: The mappings below should be moved to individual model classes. + +PACKED_MODULES_CFG = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], +} + +TARGET_MODULES_QKV = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", +] + +EMBEDDING_MODULES = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", +} + +EMBEDDING_PADDING_MODULES = ["lm_head"] + +_GLOBAL_LORA_ID = 0 + + +def convert_mapping( + mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], + max_loras: int, vocab_size: int, extra_vocab_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + indices_len: List of lengths of the above tensors. 
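
As a concrete picture of the input side, here is a sketch of the `LoRAMapping` this function consumes, assuming two registered adapters with ids 1 and 2 (id 0 means "no LoRA"):

```python
from vllm.lora.layers import LoRAMapping

# Batch of three prompts: the first (3 tokens) uses adapter id 1, the second
# (2 tokens) uses only the base model, the third (1 token) uses adapter id 2.
mapping = LoRAMapping(
    index_mapping=(1, 1, 1, 0, 0, 2),  # one entry per input token
    prompt_mapping=(1, 0, 2),          # one entry per sampled sequence
)
# convert_mapping() turns this into the base/sampler/embedding index tensors
# described in the docstring above.
```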
+ """ + indices = list(mapping.index_mapping).copy() + embedding_indices = indices.copy() + lora_indices = indices.copy() + prompt_mapping = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(indices[i]) + if indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if indices[i] > 0 else 0 + indices[i] = i + lora_indices[i] = lora_idx + + indices = torch.tensor([indices, lora_indices, embedding_indices], + dtype=torch.long, + device="cuda") + prompt_mapping = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size) + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = ( + torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + + (sampler_indices_padded * len(sampler_indices_padded))) + indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1]) + + return (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, indices_len) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +class LoRAModel: + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRALayerWeights], + ) -> None: + self.id = lora_model_id + assert (lora_model_id > + 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRALayerWeights] = loras + + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + rank: int, + lora_alpha: int, + tensors: Dict[str, torch.Tensor], + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and not in_wsl() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + EMBEDDING_MODULES[embeddings_module]].to( + device=device, dtype=dtype) + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights(module_name, rank, + lora_alpha, None, None, + lora_embeddings_tensor) + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + 
module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + if any(name in module_name + for name in EMBEDDING_PADDING_MODULES + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + return cls(lora_model_id, rank, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint.""" + lora_config_path = os.path.join(lora_dir, "adapter_config.json") + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + if os.path.isfile(lora_tensor_path): + tensors = safetensors.torch.load_file(lora_tensor_path) + elif os.path.isfile(lora_bin_file_path): + tensors = torch.load(lora_bin_file_path) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path) + + with open(lora_config_path) as f: + config = json.load(f) + rank = config["r"] + lora_alpha = config["lora_alpha"] + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + rank=rank, + lora_alpha=lora_alpha, + tensors=tensors, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + ) + + +class LoRAModelManager: + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + lora_target_modules: the target modules patterns to be adapted. + Support both single module name and a list of module names. + packed_modules_mapping: the mapping for packed modules. vLLM + packs some modules into one module, e.g., qkv_proj + is packed of q_proj, k_proj, and v_proj. These modules + have a single layer in the original model, but they are split + into multiple layers in the adapted model. 
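
The target-module patterns accepted here are matched against fully qualified module names with a suffix rule (see `_match_target_modules` further down). A small standalone sketch of the same check:

```python
import re

def matches_target(target_module: str, module_name: str) -> bool:
    # Mirrors LoRAModelManager._match_target_modules: a pattern matches either
    # the exact module name or any module whose last path component equals it.
    return bool(
        re.match(rf".*\.{target_module}$", module_name)
    ) or target_module == module_name

assert matches_target("qkv_proj", "model.layers.0.self_attn.qkv_proj")
assert matches_target("lm_head", "lm_head")
assert not matches_target("qkv_proj", "model.layers.0.self_attn.o_proj")
```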
+ """ + self.lora_config = lora_config + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.lora_slots + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.vocab_size = vocab_size + self.base_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.embeddings_indices = torch.empty(2, + self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.offsets = [] + # 4 is the number of indicies tensors defined above + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices + self.indices_len = [None] * 4 + + self.model: nn.Module = model + self.lora_target_modules: List[str] = ([ + lora_target_modules + ] if isinstance(lora_target_modules, str) else lora_target_modules) + self.lora_target_modules = copy.deepcopy(lora_target_modules) + self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, "BaseLayerWithLoRA"] = {} + self._registered_loras: Dict[int, LoRAModel] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_loras: Dict[int, None] = {} + self._last_mapping = None + self._create_lora_modules() + self.model.lora_manager = self + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def lora_slots(self) -> int: + return self.lora_config.max_loras + + def __len__(self) -> int: + return len(self._registered_loras) + + def activate_lora( + self, + lora_id: int, + ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" + if lora_id in self._active_loras: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_loras[lora_id] = None + lora_model = self._registered_loras[lora_id] + logger.debug( + f"Activating LoRA. 
int id: {lora_model.id}, slot index: {index}") + self.lora_index_to_id[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor) + else: + module.reset_lora(index) + return True + + def _deactivate_lora(self, lora_id: int): + try: + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None + except ValueError: + pass + + def deactivate_lora(self, lora_id: int) -> bool: + """Remove a LoRA from a GPU buffer.""" + if lora_id in self._active_loras: + self._deactivate_lora(lora_id) + self._active_loras.pop(lora_id) + return True + return False + + def _add_lora(self, lora: LoRAModel) -> bool: + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager CPU cache.""" + if lora.id not in self._registered_loras: + if len(self._registered_loras) >= self.capacity: + raise RuntimeError("No free LoRA slots.") + self._add_lora(lora) + return True + return False + + def remove_lora(self, lora_id: int) -> bool: + """Remove a LoRAModel from the manager CPU cache.""" + # TODO: should we check active lora? + self.deactivate_lora(lora_id) + return bool(self._registered_loras.pop(lora_id, None)) + + # TODO see if this can be vectorized + def _set_lora_mapping(self, mapping: LoRAMapping) -> None: + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, + indices_len) = convert_mapping(mapping, self.lora_index_to_id, + self.lora_slots + 1, self.vocab_size, + self.lora_config.lora_extra_vocab_size) + self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self.embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + # Maintain the reference + self.indices_len[:] = indices_len + + def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + if self._last_mapping != lora_mapping: + self._set_lora_mapping(lora_mapping) + self._last_mapping = lora_mapping + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras) + + def get_lora(self, lora_id: int) -> Optional[LoRAModel]: + return self._registered_loras.get(lora_id, None) + + def remove_all_loras(self) -> bool: + """Remove all LoRAModels from the manager.""" + self._registered_loras.clear() + self.lora_index_to_id = [None] * self.lora_slots + self._active_loras.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name): + continue + + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.lora_slots, self.lora_config, + self.model.config)) + # (yard1): TODO make this more robust + if "lm_head" in module_name: + sampler_module = self.model.get_submodule("sampler") + new_module = replace_submodule( + self.model, "sampler", + from_layer_sampler(sampler_module, module, self.lora_slots, + self.lora_config, self.model.config)) + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + new_module.set_mapping(self.base_indices, self.sampler_indices, + self.sampler_indices_padded, + self.embeddings_indices, self.indices_len) + + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}) + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name) or not isinstance( + module, BaseLayerWithLoRA): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + if parts[-1] in EMBEDDING_MODULES: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + input_dim, + output_dim, + rank, + module.lora_a_stacked.dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim) + else: + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + module.lora_a_stacked.shape[-1], + module.lora_b_stacked.shape[-2], + rank, + module.lora_a_stacked.dtype, + "cpu", + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras = [] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." 
+ r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + ) + lora.optimize() + subloras.append(lora) + lora = PackedLoRALayerWeights.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.lora_target_modules) + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name) + if not replacements: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) + + +class LoRALRUCache(LRUCache): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_lora_fn = deactivate_lora_fn + + def _on_remove(self, key: Hashable, value: Any): + logger.debug(f"Removing LoRA. int id: {key}") + self.deactivate_lora_fn(key) + return super()._on_remove(key, value) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, lora_target_modules, + packed_modules_mapping) + self._registered_loras: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_lora) + self._active_loras: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_lora) + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras.cache) + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + if lora.id not in self._registered_loras: + self._add_lora(lora) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_loras.touch(lora.id) + was_added = False + return was_added + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_loras and len( + self._active_loras) >= self.lora_slots: + self._active_loras.remove_oldest() + result = super().activate_lora(lora_id) + # We always touch to update the LRU cache order + self._active_loras.touch(lora_id) + return result + + def remove_oldest_lora(self) -> bool: + if len(self._registered_loras) > 0: + self._registered_loras.remove_oldest() + return True + return False + + +def create_lora_manager( + model: nn.Module, + max_num_seqs: int, + 
max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not getattr(model, "supports_lora", False): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + lora_target_modules=target_modules, + **kwargs) + return lora_manager diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py new file mode 100644 index 0000000000000..ac96931b2d071 --- /dev/null +++ b/vllm/lora/punica.py @@ -0,0 +1,173 @@ +# Based on code from https://github.com/punica-ai/punica + +from typing import Optional + +import torch + +import_exc = None + +try: + import vllm._punica_C as punica_kernels +except ImportError as e: + import_exc = e + +if import_exc is None: + + def bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + ): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight + matrices. + indicies: Shape: `[B]`. Indices of the weight matrices. + layer_idx: Layer index of the weight matrices. + scale: Scaling factor. + """ + punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + + def add_lora(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + *, + buffer: Optional[torch.Tensor] = None): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + buffer: Optional. Shape: `[B, R]`. Temporary buffer. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical innacuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, + 1.0) + punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, + scale) + + def add_lora_slice(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + y_offset: int, + y_slice_size: int, + *, + buffer: Optional[torch.Tensor] = None): + """ + Same as `add_lora` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. 
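
The documented semantics of `add_lora` above can be restated eagerly in plain PyTorch; this reference sketch is for readability only and does not reflect how `dispatch_bgmv` actually batches the work on the GPU:

```python
import torch

def add_lora_reference(y, x, wa_t_all, wb_t_all, indices, layer_idx, scale):
    # y: (B, H2), x: (B, H1), wa_t_all: (N, L, R, H1), wb_t_all: (N, L, H2, R)
    for i in range(x.size(0)):
        wa = wa_t_all[indices[i], layer_idx]   # (R, H1)
        wb = wb_t_all[indices[i], layer_idx]   # (H2, R)
        y[i] += (x[i] @ wa.t()) @ wb.t() * scale
    return y

B, H1, H2, R, L = 4, 16, 32, 8, 1
y = torch.zeros(B, H2)
x = torch.randn(B, H1)
wa_t_all = torch.randn(2, L, R, H1)   # two adapters
wb_t_all = torch.randn(2, L, H2, R)
add_lora_reference(y, x, wa_t_all, wb_t_all, torch.tensor([0, 1, 0, 1]), 0, 1.0)
```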
+ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv_low_level( + buffer, + x, + wa_t_all, + indicies, + layer_idx, + 1.0, + x.size(1), + buffer.size(1), + 0, + ) + punica_kernels.dispatch_bgmv_low_level( + y, + buffer, + wb_t_all, + indicies, + layer_idx, + scale, + buffer.size(1), + y_slice_size, + y_offset, + ) + +else: + + def _raise_exc( + *args, # pylint: disable=unused-argument + **kwargs # pylint: disable=unused-argument + ): + if torch.cuda.get_device_capability() < (8, 0): + raise ImportError( + "LoRA kernels require compute capability>=8.0") from import_exc + else: + raise import_exc + + bgmv = _raise_exc + add_lora = _raise_exc + add_lora_slice = _raise_exc + +__all__ = [ + "bgmv", + "add_lora", + "add_lora_slice", +] diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000000000..5d45f8a0f396d --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + + +@dataclass +class LoRARequest: + """ + Request for a LoRA adapter. + + Note that this class should be be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_id and lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + + lora_id: str + lora_int_id: int + lora_local_path: str + + def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError( + f"lora_int_id must be > 0, got {self.lora_int_id}") + + def __eq__(self, value: object) -> bool: + return isinstance(value, LoRARequest) and self.lora_id == value.lora_id + + def __hash__(self) -> int: + return self.lora_int_id diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000000000..f67a3812fb046 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,39 @@ +import logging +from typing import Tuple + +from torch import nn + +logger = logging.getLogger(__name__) + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. 
+ base_model.model.dense1.weight + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + """ + parts = name.split(".") + assert parts[0] == "base_model" + assert parts[1] == "model" + if parts[-1] == "weight": + assert parts[-2] == "lora_A" or parts[-2] == "lora_B" + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + + raise ValueError(f"{name} is unsupported format") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000000000..a507c08588dad --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,237 @@ +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional, Set, Type, Union + +import torch + +from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.request import LoRARequest +from vllm.lora.layers import LoRAMapping +from vllm.config import LoRAConfig + +logger = logging.getLogger(__name__) + + +class WorkerLoRAManager(ABC): + """Abstract class for managing LoRA models on the worker side.""" + + def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, + vocab_size: int, lora_config: LoRAConfig, + device: torch.device): + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.device = device + self.lora_config = lora_config + + @abstractproperty + def is_enabled(self) -> bool: + ... + + @abstractmethod + def create_lora_manager( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + ... + + @abstractmethod + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + ... + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + ... + + @abstractmethod + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + ... + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + ... + + @abstractmethod + def remove_all_loras(self) -> bool: + ... + + @abstractmethod + def list_loras(self) -> Set[int]: + ... + + +class WorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. 
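
A sketch of how a worker is expected to drive this interface per batch; `worker_lora_manager` is a hypothetical instance of the concrete manager defined here, and the adapter path is illustrative:

```python
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest

# Hypothetical adapter; lora_int_id must be >= 1 and unique per adapter.
sql_lora = LoRARequest(lora_id="sql-lora", lora_int_id=1,
                       lora_local_path="/path/to/sql_lora_adapter")

# Token-level mapping for a 3-token prompt that uses this adapter.
mapping = LoRAMapping(index_mapping=(1, 1, 1), prompt_mapping=(1,))

# Once per batch a worker would call:
#   worker_lora_manager.set_active_loras([sql_lora], mapping)
# which loads/activates adapter 1 and unloads adapters no longer requested.
```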
+ + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_model_cls: Type[LoRAModel] = LoRAModel, + ): + self._lora_manager: Optional[LoRAModelManager] = None + self._lora_model_cls = lora_model_cls + super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, + lora_config, device) + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + target_modules=target_modules, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + lora_manager_cls=self._lora_manager_cls, + ) + self._lora_manager: LoRAModelManager = lora_manager + return lora_manager.model + + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self._apply_loras(lora_requests) + self._lora_manager.set_lora_mapping(lora_mapping) + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_that_exist = self.list_loras() + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.lora_slots}).") + + new_loras = set(loras_map) + loras_to_add = new_loras - loras_that_exist + loras_to_remove = loras_that_exist - new_loras + + for lora_id in loras_to_remove: + self.remove_lora(lora_id) + + for lora_id in loras_to_add: + self.add_lora(loras_map[lora_id]) + + def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + try: + lora = self._lora_model_cls.from_local_checkpoint( + lora_request.lora_local_path, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + ) + except Exception as e: + raise RuntimeError( + f"Loading lora {lora_request.lora_local_path} failed") from e + if lora.rank > self.lora_config.max_lora_rank: + raise ValueError( + f"LoRA rank {lora.rank} is greater than max_lora_rank " + f"{self.lora_config.max_lora_rank}.") + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than " + f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}." 
+ ) + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + return self._lora_manager.add_lora( + self._lora_manager.create_dummy_lora(lora_request.lora_int_id, + rank)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + self._lora_manager.activate_lora(lora.id) + return loaded + + def remove_lora(self, lora_id: int) -> bool: + return self._lora_manager.remove_lora(lora_id) + + def remove_all_loras(self) -> bool: + self._lora_manager.remove_all_loras() + + def list_loras(self) -> Set[int]: + return set(self._lora_manager.list_loras()) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _lora_manager_cls: Type[ + LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_manager( + model, + target_modules=target_modules, + lora_manager_cls=self._lora_manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._lora_manager: LRUCacheLoRAModelManager = lora_manager + return lora_manager.model + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_lora(lora) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_loras(): + # Remove before we load the new lora to save memory + if len(self._lora_manager) + 1 > self._lora_manager.capacity: + self._lora_manager.remove_oldest_lora() + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + self._lora_manager.activate_lora(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index f9d95fa7548fd..72cebdce8090d 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -27,9 +27,24 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). """ - def __init__(self, vocab_size: int) -> None: + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: super().__init__() self.vocab_size = vocab_size + # original vocabulary size (without LoRA). 
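+        # _get_logits below slices the logits back to this size, so any
+        # padded vocab positions are dropped before sampling.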
+ self.org_vocab_size = org_vocab_size or vocab_size + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :self.org_vocab_size] + return logits def forward( self, @@ -42,8 +57,7 @@ def forward( hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) # Get the logits for the next tokens. - logits = _get_logits(hidden_states, embedding, embedding_bias, - self.vocab_size) + logits = self._get_logits(hidden_states, embedding, embedding_bias) _, vocab_size = logits.shape @@ -90,19 +104,6 @@ def forward( prompt_logprobs, sample_logprobs) -def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor], - vocab_size: int) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :vocab_size] - return logits - - def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b08d5555b0faa..9c5fb890251ed 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -13,8 +13,11 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs +DEFAULT_VOCAB_PADDING_SIZE = 64 -def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int: + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to @@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module): num_embeddings: vocabulary size. embedding_dim: size of hidden state. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. """ def __init__(self, num_embeddings: int, embedding_dim: int, - params_dtype: Optional[torch.dtype] = None): + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): super().__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings - self.num_embeddings_padded = pad_vocab_size(num_embeddings) + self.org_vocab_size = org_num_embeddings or num_embeddings + self.num_embeddings_padded = pad_vocab_size(num_embeddings, + padding_size) self.embedding_dim = embedding_dim if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -77,7 +86,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): parallel_dim = param.parallel_dim - assert loaded_weight.shape[parallel_dim] == self.num_embeddings + assert loaded_weight.shape[parallel_dim] == self.org_vocab_size loaded_weight = loaded_weight[self.vocab_start_index:self. 
vocab_end_index] param[:loaded_weight.shape[0]].data.copy_(loaded_weight) @@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding): embedding_dim: size of hidden state. bias: whether to use bias. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. """ def __init__(self, num_embeddings: int, embedding_dim: int, bias: bool = False, - params_dtype: Optional[torch.dtype] = None): - super().__init__(num_embeddings, embedding_dim, params_dtype) + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 37543d8c9838e..0f1125e5c8e3e 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -1,12 +1,12 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Type +from typing import Optional, Type import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig +from vllm.config import ModelConfig, LoRAConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -32,7 +32,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {ModelRegistry.get_supported_archs()}") -def get_model(model_config: ModelConfig) -> nn.Module: +def get_model(model_config: ModelConfig, + lora_config: Optional[LoRAConfig] = None) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. @@ -62,7 +63,17 @@ def get_model(model_config: ModelConfig) -> nn.Module: # Create a model instance. # The weights will be initialized as empty tensors. with torch.device("cuda"): - model = model_class(model_config.hf_config, linear_method) + if getattr(model_class, "supports_lora", False): + model = model_class(model_config.hf_config, linear_method, + lora_config) + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + else: + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
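
[Note] To make the vocab padding in `vocab_parallel_embedding.py` above concrete, here is a minimal, self-contained sketch (not the vLLM source) of the same rounding that `pad_vocab_size` performs with `DEFAULT_VOCAB_PADDING_SIZE = 64`. The base and extra vocabulary sizes are made-up numbers for illustration only.

```python
# Standalone sketch of the ceiling rounding used by pad_vocab_size above.
# The vocabulary sizes are illustrative, not taken from a real model config.

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    """Round vocab_size up to the next multiple of pad_to."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

print(pad_vocab_size(32_000))        # 32000, already a multiple of 64
print(pad_vocab_size(32_000 + 256))  # 32256, with 256 hypothetical extra LoRA tokens
print(pad_vocab_size(32_000 + 3))    # 32064, rounded up to the next multiple
```

The `org_num_embeddings` argument threaded through `VocabParallelEmbedding` and `ParallelLMHead` keeps the unpadded size around; it is what `weight_loader` asserts against and what the sampler later slices the logits back to.
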
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b3b24ea6fea44..ddae87b07b978 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -45,6 +45,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -225,14 +226,19 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ LlamaDecoderLayer(config, linear_method) @@ -263,18 +269,31 @@ def forward( class LlamaForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + # We need bigger padding if using lora for kernel + # compatibility + padding_size=64 + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 57230fcced9ff..058a219d3f5a6 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -45,6 +45,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -220,15 +221,20 @@ def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ MistralDecoderLayer(config, linear_method) @@ -259,18 +265,33 @@ def forward( class MistralForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: 
Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = MistralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = MistralModel(config, + linear_method, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + # We need bigger padding if using lora for kernel + # compatibility + padding_size=64 + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/outputs.py b/vllm/outputs.py index fe54926e06e64..534e9d5ea8a53 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -2,6 +2,7 @@ from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, SequenceStatus) +from vllm.lora.request import LoRARequest class CompletionOutput: @@ -16,6 +17,7 @@ class CompletionOutput: logprobs: The log probabilities of the top probability words at each position if the logprobs are requested. finish_reason: The reason why the sequence is finished. + lora_request: The LoRA request that was used to generate the output. """ def __init__( @@ -26,6 +28,7 @@ def __init__( cumulative_logprob: float, logprobs: Optional[SampleLogprobs], finish_reason: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: self.index = index self.text = text @@ -33,6 +36,7 @@ def __init__( self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs self.finish_reason = finish_reason + self.lora_request = lora_request def finished(self) -> bool: return self.finish_reason is not None @@ -56,6 +60,7 @@ class RequestOutput: prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. + lora_request: The LoRA request that was used to generate the output. 
""" def __init__( @@ -66,6 +71,7 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -73,6 +79,7 @@ def __init__( self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished + self.lora_request = lora_request @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -108,8 +115,13 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": prompt_token_ids = seq_group.prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished) + return cls(seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + lora_request=seq_group.lora_request) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -117,4 +129,5 @@ def __repr__(self) -> str: f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " - f"finished={self.finished})") + f"finished={self.finished}, " + f"lora_request={self.lora_request})") diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d36eeac0aa02..ac33a514b166e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from vllm.block import LogicalTokenBlock from vllm.sampling_params import SamplingParams +from vllm.lora.request import LoRARequest PromptLogprobs = List[Optional[Dict[int, float]]] SampleLogprobs = List[Dict[int, float]] @@ -105,6 +106,7 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + lora_request: LoRA request. """ def __init__( @@ -113,10 +115,12 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size + self.lora_request = lora_request self.data = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -133,6 +137,10 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -228,6 +236,7 @@ class SequenceGroup: seqs: The list of sequences. sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. + lora_request: LoRA request. """ def __init__( @@ -236,11 +245,13 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -255,6 +266,10 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. 
return next(iter(self.seqs_dict.values())).data.prompt_token_ids + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" @@ -335,6 +350,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + lora_request: LoRA request. """ def __init__( @@ -344,12 +360,18 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.lora_request = lora_request + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 class SequenceOutput: diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index a67d2f83a2549..532c7a4e6c1dc 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -4,6 +4,8 @@ PreTrainedTokenizerFast) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.utils import make_async, LRUCache from vllm.transformers_utils.tokenizers import * logger = init_logger(__name__) @@ -65,6 +67,83 @@ def get_tokenizer( return tokenizer +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[PreTrainedTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_local_path, *args, + **kwargs) + except OSError as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + f"No tokenizer found in {lora_request.lora_local_path}, " + "using base model tokenizer instead. 
" + f"(Exception: {str(e)})") + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) + + +class MultiLoRATokenizer: + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/utils.py b/vllm/utils.py index eff5d10fd4ee0..37d18e7854c99 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,9 +5,20 @@ import psutil import torch +import asyncio +from functools import partial +from typing import ( + Awaitable, + Callable, + TypeVar, +) +from collections import OrderedDict +from typing import Any, Hashable, Optional from vllm._C import cuda_utils +T = TypeVar("T") + class Device(enum.Enum): GPU = enum.auto() @@ -32,6 +43,73 @@ def is_hip() -> bool: return torch.version.hip is not None +class LRUCache: + + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> Any: + return self.get(key) + + def __setitem__(self, key: Hashable, value: Any) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, value: Any) -> None: 
+ self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def _on_remove(self, key: Hashable, value: Any): + pass + + def remove_oldest(self): + if not self.cache: + return + key, value = self.cache.popitem(last=False) + self._on_remove(key, value) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, key: int, default_value: Optional[Any] = None) -> Any: + run_on_remove = key in self.cache + value = self.cache.pop(key, default_value) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest() + self.cache.clear() + + +def is_hip() -> bool: + return torch.version.hip is not None + + def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html @@ -55,6 +133,22 @@ def in_wsl() -> bool: return "microsoft" in " ".join(uname()).lower() +def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args, **kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper + + def get_open_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fb7a0c17d6f9f..6d990e5eb63a0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,21 +1,25 @@ import time -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Set, Union import numpy as np import torch import torch.nn as nn -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest from vllm.utils import in_wsl logger = init_logger(__name__) KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 +LORA_WARMUP_RANK = 8 # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] @@ -28,17 +32,21 @@ def __init__( model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.lora_config = lora_config # model_config can be None in tests/samplers/test_sampler.py. # FIXME(woosuk): This is a hack to make the tests work. Refactor this. 
self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) + self.device = torch.device(torch.cuda.current_device()) self.model = None self.block_size = None # Set after initial profiling. + self.lora_manager = None self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool = None # Set during graph capture. @@ -57,7 +65,17 @@ def __init__( self.in_wsl = in_wsl() def load_model(self) -> None: - self.model = get_model(self.model_config) + self.model = get_model(self.model_config, self.lora_config) + + vocab_size = self.model.config.vocab_size + + if self.lora_config: + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens + + self.scheduler_config.max_paddings, vocab_size, + self.lora_config, self.device) + self.model = self.lora_manager.create_lora_manager(self.model) def set_block_size(self, block_size: int) -> None: self.block_size = block_size @@ -70,11 +88,15 @@ def set_block_size(self, block_size: int) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() prompt_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: @@ -93,6 +115,17 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.append(list(range(prompt_len))) + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping.append([lora_id] * prompt_len) + lora_prompt_mapping.extend( + [lora_id] * + (prompt_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. 
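
[Note] The hunk above builds two parallel structures: `lora_index_mapping` holds one LoRA id per prompt token, while `lora_prompt_mapping` holds one id per position that will produce logits, that is, every prompt token when prompt logprobs are requested and otherwise one per sequence. The sketch below is illustrative only; the helper name and the batch contents are invented for this example and are not part of vLLM.

```python
# Illustrative sketch of how _prepare_prompt builds the LoRA mappings.
# build_prompt_lora_mappings is a hypothetical helper written for this
# example; the sequence lengths and LoRA ids are made up.

from typing import List, Tuple


def build_prompt_lora_mappings(
    prompt_lens: List[int],
    lora_ids: List[int],
    wants_prompt_logprobs: List[bool],
) -> Tuple[List[List[int]], List[int]]:
    lora_index_mapping: List[List[int]] = []  # one LoRA id per prompt token
    lora_prompt_mapping: List[int] = []       # one LoRA id per sampled logit row
    for prompt_len, lora_id, logprobs in zip(prompt_lens, lora_ids,
                                             wants_prompt_logprobs):
        lora_index_mapping.append([lora_id] * prompt_len)
        lora_prompt_mapping.extend(
            [lora_id] * (prompt_len if logprobs else 1))
    return lora_index_mapping, lora_prompt_mapping


# Two sequences: the first uses LoRA id 1 with 3 prompt tokens, the second
# runs on the base model (id 0) with 2 prompt tokens and prompt logprobs.
index_mapping, prompt_mapping = build_prompt_lora_mappings(
    prompt_lens=[3, 2], lora_ids=[1, 0], wants_prompt_logprobs=[False, True])
print(index_mapping)   # [[1, 1, 1], [0, 0]]
print(prompt_mapping)  # [1, 0, 0]
```

In the hunks that follow, the per-sequence lists are padded to the maximum prompt length, flattened, and wrapped in a `LoRAMapping` before being handed to `set_active_loras`.
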
@@ -133,7 +166,10 @@ def _prepare_prompt( max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long) - + lora_index_mapping = [ + _pad_to_max(mapping, max_prompt_len, pad=0) + for mapping in lora_index_mapping + ] input_metadata = InputMetadata( prompt_lens=prompt_lens, slot_mapping=slot_mapping, @@ -142,23 +178,32 @@ def _prepare_prompt( block_tables=None, use_cuda_graph=False, ) - return input_tokens, input_positions, input_metadata + return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] + lora_index_mapping: List[List[int]] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() @@ -177,6 +222,8 @@ def _prepare_decode( block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) + lora_index_mapping.append([lora_id]) + lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // @@ -229,7 +276,7 @@ def _prepare_decode( dtype=torch.int, device=device, pin_memory=pin_memory) - + if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -246,6 +293,10 @@ def _prepare_decode( dtype=torch.int, ) + lora_index_mapping = [ + _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping + ] + input_metadata = InputMetadata( prompt_lens=[], slot_mapping=slot_mapping, @@ -254,7 +305,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return input_tokens, input_positions, input_metadata + return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests def _prepare_sample( self, @@ -338,10 +389,20 @@ def execute_model( # Prepare input tensors. if is_prompt: inputs = self._prepare_prompt(seq_group_metadata_list) - input_tokens, input_positions, input_metadata = inputs + input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests = inputs else: inputs = self._prepare_decode(seq_group_metadata_list) - input_tokens, input_positions, input_metadata = inputs + input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests = inputs + + if self.lora_config: + flat_lora_index_mapping = [ + item for sublist in lora_index_mapping for item in sublist + ] + lora_mapping = LoRAMapping( + flat_lora_index_mapping, + lora_prompt_mapping, + ) + self.set_active_loras(lora_requests, lora_mapping) # Execute the model. 
if input_metadata.use_cuda_graph: @@ -374,6 +435,28 @@ def profile_run(self) -> None: max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs + # max_loras is the maximum number of requests with distinct LoRAs + # that can run concurrently, and therefore bounds LoRA memory + # consumption. Create that many dummy LoRA requests, each backed by + # dummy weights of rank LORA_WARMUP_RANK, so that profiling captures + # this worst case. + dummy_lora_requests = [] + dummy_lora_requests_per_seq = [] + if self.lora_config: + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_id=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] @@ -387,6 +470,8 @@ def profile_run(self) -> None: seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, ) seqs.append(seq) @@ -397,6 +482,32 @@ def profile_run(self) -> None: torch.cuda.synchronize() return + def remove_all_loras(self) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_all_loras() + + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_loras() + @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: assert not self.model_config.enforce_eager @@ -432,6 +543,13 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: use_cuda_graph=True, ) + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + graph_runner = CUDAGraphRunner(self.model) graph_runner.capture( input_tokens[:batch_size], diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8698b15721507..bb8e7fd6cf86e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,18 +1,20 @@ """A GPU worker class.""" +import gc import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple, Set, Optional import torch import torch.distributed from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sequence import SamplerOutput, 
SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner +from vllm.lora.request import LoRARequest class Worker: @@ -30,15 +32,17 @@ def __init__( scheduler_config: SchedulerConfig, rank: Optional[int] = None, distributed_init_method: Optional[str] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.rank = rank self.distributed_init_method = distributed_init_method + self.lora_config = lora_config self.model_runner = ModelRunner(model_config, parallel_config, - scheduler_config) + scheduler_config, lora_config) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None @@ -107,6 +111,9 @@ def profile_num_available_blocks( num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks @@ -160,6 +167,15 @@ def execute_model( self.gpu_cache) return output + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + def _init_distributed_environment( parallel_config: ParallelConfig,