diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index 36a7ee37fd228..23c962a3d7075 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -52,6 +52,7 @@ RUN pip install xformers==0.0.23 --no-deps
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
+ && bash patch_rocm.rocm.sh \
&& bash patch_xformers.rocm.sh \
&& python3 setup.py install \
&& cd ..
diff --git a/README.md b/README.md
index 8ea4d029dc64f..85c5068b69cb9 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,20 @@
+multi-lora rocm development
+
+Derived from [Yard1's multi-lora branch](https://github.com/Yard1/vllm/tree/multi_lora)
+
+[Important note]
+
+Starting from ROCm v5.7, some type conversion functions on bfloat16 are included and implemented in header files. Unfortunately a few of the host functions are not specified as inline or static, so building the project on ROCm directly would result in ODR violations when the translation units are being linked.
+
+A way to circumvent this is to manually add the `inline` or `static` keywards to the related functions. In the `rocm/pytorch` container that `Dockerfile.rocm` builds from, it means adding the keyword `inline` to `/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h:96` so that the line becomes
+
+```cpp
+L96: #define __HOST_DEVICE__ __host__ __device__ inline
+```
+
+This is far from a pretty solution though. Even though it appears that [ROCm is fixing this](https://github.com/ROCm/clr/commit/86bd518981b364c138f9901b28a529899d8654f3), it doesn't seem to be included in ROCm 6.0.0. Fixes like this may need to stay around until better solutions come out.
+
+
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index e33d5fb2dc247..d75d690cc66d4 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -65,7 +65,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
- profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+ profile_dir = Path(
+ "."
+ ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=args.profile_result_dir)
return
@@ -123,9 +125,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
'--profile-result-dir',
type=str,
default=None,
- help=(
- 'path to save the pytorch profiler output. Can be visualized '
- 'with ui.perfetto.dev or Tensorboard.'
- ))
+ help=('path to save the pytorch profiler output. Can be visualized '
+ 'with ui.perfetto.dev or Tensorboard.'))
args = parser.parse_args()
main(args)
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
index aa58dd73c148a..1eef4c34607f0 100644
--- a/csrc/cuda_compat.h
+++ b/csrc/cuda_compat.h
@@ -18,6 +18,12 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif
+#ifndef USE_ROCM
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta)
+#else
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
+#endif
+
#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE
new file mode 100644
index 0000000000000..a46e2cdcadf7d
--- /dev/null
+++ b/csrc/punica/LICENSE
@@ -0,0 +1,217 @@
+Contains code from https://github.com/punica-ai/punica
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+------------------------------------------------------------------------------------
+
+This product bundles various third-party components under other open source licenses.
+This section summarizes those components and their licenses. See licenses/
+for text of these licenses.
+
+
+Apache-2.0
+* third_party/nvbench (with LLVM exception)
+* third_party/flashinfer
+
+BSD-3-Clause:
+* third_party/cutlass
\ No newline at end of file
diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu
new file mode 100644
index 0000000000000..2502a67e3c813
--- /dev/null
+++ b/csrc/punica/bgmv/bgmv_all.cu
@@ -0,0 +1,21 @@
+#include "bgmv_config.h"
+#include "bgmv_impl.cuh"
+
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
new file mode 100644
index 0000000000000..664dddc680ab6
--- /dev/null
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -0,0 +1,63 @@
+#pragma once
+
+template
+void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
+ const W_T *__restrict__ W,
+ const int64_t *__restrict__ indicies, int64_t y_offset,
+ int64_t full_y_size, int64_t batch_size, int64_t num_layers,
+ int64_t layer_idx, float scale);
+
+// clang-format off
+
+#define FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, narrow) \
+ f(in_T, out_T, W_T, narrow, 256) \
+ f(in_T, out_T, W_T, narrow, 512) \
+ f(in_T, out_T, W_T, narrow, 1024) \
+ f(in_T, out_T, W_T, narrow, 1280) \
+ f(in_T, out_T, W_T, narrow, 1728) \
+ f(in_T, out_T, W_T, narrow, 1792) \
+ f(in_T, out_T, W_T, narrow, 2048) \
+ f(in_T, out_T, W_T, narrow, 2560) \
+ f(in_T, out_T, W_T, narrow, 2752) \
+ f(in_T, out_T, W_T, narrow, 3072) \
+ f(in_T, out_T, W_T, narrow, 3456) \
+ f(in_T, out_T, W_T, narrow, 3584) \
+ f(in_T, out_T, W_T, narrow, 4096) \
+ f(in_T, out_T, W_T, narrow, 5120) \
+ f(in_T, out_T, W_T, narrow, 5504) \
+ f(in_T, out_T, W_T, narrow, 6912) \
+ f(in_T, out_T, W_T, narrow, 7168) \
+ f(in_T, out_T, W_T, narrow, 8192) \
+ f(in_T, out_T, W_T, narrow, 9216) \
+ f(in_T, out_T, W_T, narrow, 10240) \
+ f(in_T, out_T, W_T, narrow, 11008) \
+ f(in_T, out_T, W_T, narrow, 12288) \
+ f(in_T, out_T, W_T, narrow, 13824) \
+ f(in_T, out_T, W_T, narrow, 14336) \
+ f(in_T, out_T, W_T, narrow, 16384) \
+ f(in_T, out_T, W_T, narrow, 20480) \
+ f(in_T, out_T, W_T, narrow, 28672) \
+ f(in_T, out_T, W_T, narrow, 32000) \
+ f(in_T, out_T, W_T, narrow, 32256) \
+ f(in_T, out_T, W_T, narrow, 32512) \
+ f(in_T, out_T, W_T, narrow, 32768) \
+ f(in_T, out_T, W_T, narrow, 33024) \
+ f(in_T, out_T, W_T, narrow, 36864) \
+ f(in_T, out_T, W_T, narrow, 49152) \
+// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+
+#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
+ f(in_T, out_T, W_T, narrow, 128) \
+ FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, narrow) \
+// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+
+// Keep this in sync with vllm/config::LoRAConfig
+#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
+ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
+ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
+ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
+ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) \
+ FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, 128)
+
+// clang-format on
diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh
new file mode 100644
index 0000000000000..2e72394647c6a
--- /dev/null
+++ b/csrc/punica/bgmv/bgmv_impl.cuh
@@ -0,0 +1,372 @@
+#pragma once
+
+#include
+#ifndef USE_ROCM
+#include
+#else
+#include
+#endif
+#ifndef USE_ROCM
+#include
+#endif
+#include
+#include
+#include
+
+#include "vec_dtypes.cuh"
+
+namespace cg = cooperative_groups;
+
+#ifdef USE_ROCM
+
+template
+__host__ __device__
+inline void* memcpy_blocking(void *dst, const void *src) {
+ // Does not handle the case of long datatypes
+ char *d = reinterpret_cast(dst);
+ const char *s = reinterpret_cast(src);
+ size_t i = 0;
+#pragma unroll
+ for (i = 0; i < len; ++i) {
+ d[i] = s[i];
+ }
+ return dst;
+}
+#endif
+
+// nthrs = (32, 4)
+template
+__global__ void
+bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
+ const W_T *__restrict__ W,
+ const int64_t *__restrict__ indicies, int64_t y_offset,
+ int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
+ float scale) {
+ size_t batch_idx = blockIdx.y;
+ int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
+ if (idx < 0) {
+ return;
+ }
+
+ auto block = cg::this_thread_block();
+ size_t j = blockIdx.x;
+ constexpr size_t num_pipeline_stages = 2;
+ constexpr size_t tile_size = tx * ty * vec_size;
+ __shared__ W_T W_shared[num_pipeline_stages * tile_size];
+ __shared__ in_T X_shared[num_pipeline_stages * tile_size];
+ __shared__ float y_warpwise[ty];
+
+ size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
+ size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
+#ifndef USE_ROCM
+ auto pipe = cuda::make_pipeline();
+
+ // pipeline load W/X and compute WX;
+ pipe.producer_acquire();
+ cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
+ W + (idx * feat_out + j) * feat_in +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ cuda::aligned_size_t(W_copy_size), pipe);
+ cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
+ X + (batch_idx * feat_in) +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ cuda::aligned_size_t(X_copy_size), pipe);
+ pipe.producer_commit();
+#else
+ memcpy_blocking(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
+ W + (idx * feat_out + j) * feat_in +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+ memcpy_blocking(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
+ X + (batch_idx * feat_in) +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+#endif
+ size_t copy_idx, compute_idx;
+ float y = 0.f;
+ vec_t x_vec;
+ vec_t w_vec;
+ size_t tile_idx;
+
+#pragma unroll
+ for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
+ ++tile_idx) {
+ copy_idx = tile_idx % num_pipeline_stages;
+ // pipeline stage: async copy W fragment
+#ifndef USE_ROCM
+ pipe.producer_acquire();
+ if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
+ cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ W + (idx * feat_out + j) * feat_in +
+ tile_idx * tile_size +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ cuda::aligned_size_t(W_copy_size), pipe);
+ cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ X + (batch_idx * feat_in) + tile_idx * tile_size +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ cuda::aligned_size_t(X_copy_size), pipe);
+ }
+ pipe.producer_commit();
+#else
+ if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
+ memcpy_blocking(W_shared + W_shared_offset[copy_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ W + (idx * feat_out + j) * feat_in +
+ tile_idx * tile_size +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+ memcpy_blocking(X_shared + X_shared_offset[copy_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
+ X + (batch_idx * feat_in) + tile_idx * tile_size +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+ }
+#endif
+
+ compute_idx = (tile_idx - 1) % num_pipeline_stages;
+ // pipeline stage: compute WX
+#ifndef USE_ROCM
+ pipe.consumer_wait();
+#endif
+ block.sync();
+ x_vec.load(X_shared + X_shared_offset[compute_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+ w_vec.load(W_shared + W_shared_offset[compute_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+ float sum = 0.f;
+#pragma unroll
+ for (size_t i = 0; i < vec_size; ++i) {
+#ifndef USE_ROCM
+ sum += float(w_vec[i]) * float(x_vec[i]) * scale;
+#else
+ sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale;
+#endif
+ }
+#pragma unroll
+ for (size_t offset = tx / 2; offset > 0; offset /= 2) {
+ sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
+ }
+ y_warpwise[threadIdx.y] = sum;
+ block.sync();
+#pragma unroll
+ for (size_t i = 0; i < ty; ++i) {
+ y += y_warpwise[i];
+ }
+
+ block.sync();
+#ifndef USE_ROCM
+ pipe.consumer_release();
+#endif
+ }
+
+ compute_idx = (tile_idx - 1) % num_pipeline_stages;
+ // final pipeline stage
+#ifndef USE_ROCM
+ pipe.consumer_wait();
+#endif
+ block.sync();
+ x_vec.load(X_shared + X_shared_offset[compute_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+ w_vec.load(W_shared + W_shared_offset[compute_idx] +
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
+ float sum = 0.f;
+#pragma unroll
+ for (size_t i = 0; i < vec_size; ++i) {
+#ifndef USE_ROCM
+ sum += float(w_vec[i]) * float(x_vec[i]) * scale;
+#else
+ sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale;
+#endif
+ }
+#pragma unroll
+ for (size_t offset = tx / 2; offset > 0; offset /= 2) {
+ sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
+ }
+ y_warpwise[threadIdx.y] =
+ ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
+ ? sum
+ : 0.f;
+ block.sync();
+#pragma unroll
+ for (size_t i = 0; i < ty; ++i) {
+ y += y_warpwise[i];
+ }
+
+ block.sync();
+#ifndef USE_ROCM
+ pipe.consumer_release();
+#endif
+
+ // write Y;
+ if (block.thread_rank() == 0) {
+#ifndef USE_ROCM
+ Y[batch_idx * full_y_size + y_offset + j] += static_cast(y);
+#else
+ size_t y_idx = batch_idx * full_y_size + y_offset + j;
+ Y[y_idx] = vllm_add(Y[y_idx], convert_type(y));
+#endif
+ }
+}
+
+// nthrs = (2, 16, 4)
+template
+__global__ void
+bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
+ const W_T *__restrict__ W,
+ const int64_t *__restrict__ indicies, int64_t y_offset,
+ int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
+ float scale) {
+ size_t batch_idx = blockIdx.y;
+ int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
+
+ if (idx < 0) {
+ return;
+ }
+
+ auto block = cg::this_thread_block();
+ size_t tile_idx = blockIdx.x;
+
+ // load X;
+ vec_t x_vec;
+ x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
+
+ // load W;
+ vec_t w_vec;
+ w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
+ block.thread_rank() * vec_size);
+
+ float sum = 0.f;
+#pragma unroll
+ for (size_t i = 0; i < vec_size; ++i) {
+#ifndef USE_ROCM
+ sum += float(w_vec[i]) * float(x_vec[i]) * scale;
+#else
+ sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale;
+#endif
+ }
+
+ cg::thread_block_tile g = cg::tiled_partition(block);
+#pragma unroll
+ for (size_t offset = tx / 2; offset > 0; offset /= 2) {
+ sum += g.shfl_down(sum, offset);
+ }
+ sum = g.shfl(sum, 0);
+
+ if (threadIdx.x == 0) {
+#ifndef USE_ROCM
+ Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
+ threadIdx.z * ty + threadIdx.y]
+ += static_cast(sum);
+#else
+ size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
+ threadIdx.z * ty + threadIdx.y;
+ Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum));
+#endif
+ }
+}
+
+template
+void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
+ const W_T *__restrict__ W,
+ const int64_t *__restrict__ indicies, int64_t y_offset,
+ int64_t full_y_size, int64_t batch_size, int64_t num_layers,
+ int64_t layer_idx, float scale) {
+ constexpr size_t vec_size = 8;
+ constexpr int tz = 4;
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ if constexpr (feat_in < feat_out) {
+ static_assert(feat_in % vec_size == 0);
+ constexpr int tx = feat_in / vec_size;
+
+ static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
+ (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
+ (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
+
+ if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
+ constexpr int ty = 32 / tx;
+ dim3 nblks(feat_out / (ty * tz), batch_size);
+ dim3 nthrs(tx, ty, tz);
+
+ bgmv_expand_kernel
+ <<>>(Y, X, W, indicies, y_offset,
+ full_y_size, num_layers, layer_idx,
+ scale);
+ } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
+ constexpr int ty = 16 / tx;
+ dim3 nblks(feat_out / (ty * tz), batch_size);
+ dim3 nthrs(tx, ty, tz);
+
+ bgmv_expand_kernel
+ <<>>(Y, X, W, indicies, y_offset,
+ full_y_size, num_layers, layer_idx,
+ scale);
+ } else {
+ constexpr int ty = 8 / tx;
+ dim3 nblks(feat_out / (ty * tz), batch_size);
+ dim3 nthrs(tx, ty, tz);
+
+ bgmv_expand_kernel
+ <<>>(Y, X, W, indicies, y_offset,
+ full_y_size, num_layers, layer_idx,
+ scale);
+ }
+ } else {
+ static_assert(feat_in % (vec_size * 32) == 0 ||
+ feat_in % (vec_size * 16) == 0 ||
+ feat_in % (vec_size * 8) == 0);
+
+ if constexpr (feat_in % (vec_size * 32) == 0) {
+ constexpr int tx = 32;
+ constexpr int ty = 4;
+
+ dim3 nblks(feat_out, batch_size);
+ dim3 nthrs(tx, ty);
+
+ bgmv_shrink_kernel
+ <<>>(Y, X, W, indicies, y_offset,
+ full_y_size, num_layers, layer_idx,
+ scale);
+ } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
+ constexpr int tx = 32;
+ constexpr int ty = 4;
+
+ dim3 nblks(feat_out, batch_size);
+ dim3 nthrs(tx, ty);
+
+ bgmv_shrink_kernel
+ <<>>(Y, X, W, indicies, y_offset,
+ full_y_size, num_layers, layer_idx,
+ scale);
+ } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
+ constexpr int tx = 16;
+ constexpr int ty = 4;
+
+ dim3 nblks(feat_out, batch_size);
+ dim3 nthrs(tx, ty);
+
+ bgmv_shrink_kernel
+ <<>>(Y, X, W, indicies, y_offset,
+ full_y_size, num_layers, layer_idx,
+ scale);
+ }
+ }
+}
+
+#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
+ template void bgmv_kernel( \
+ out_T * __restrict__ Y, const in_T *__restrict__ X, \
+ const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
+ int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
+ int64_t num_layers, int64_t layer_idx, float scale);
+
+#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
+ INST_BGMV(narrow, wide, in_T, out_T, W_T) \
+ INST_BGMV(wide, narrow, in_T, out_T, W_T)
diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh
new file mode 100644
index 0000000000000..2738892e6dc4a
--- /dev/null
+++ b/csrc/punica/bgmv/vec_dtypes.cuh
@@ -0,0 +1,1325 @@
+#ifndef VEC_DTYPES_CUH_
+#define VEC_DTYPES_CUH_
+
+#ifdef FLASHINFER_USE_FP8
+#include
+#endif
+#include
+
+#include
+
+#include "../type_convert.h"
+#include "../../cuda_compat.h"
+
+#define FLASHINFER_INLINE \
+ inline __attribute__((always_inline)) __device__ __host__
+
+template
+struct vec_t {
+ FLASHINFER_INLINE float_t &operator[](size_t i);
+ FLASHINFER_INLINE const float_t &operator[](size_t i) const;
+ FLASHINFER_INLINE void fill(float_t val);
+ FLASHINFER_INLINE void load(const float_t *ptr);
+ FLASHINFER_INLINE void store(float_t *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src);
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr);
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const;
+ FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src);
+};
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t &dst) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size; ++i) {
+ dst[i] = tgt_float_t(src[i]);
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr,
+ vec_t &dst) {
+ if constexpr (std::is_same::value) {
+ dst.load(src_ptr);
+ } else {
+ vec_t tmp;
+ tmp.load(src_ptr);
+ dst.cast_from(tmp);
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_store_impl(const vec_t &src,
+ tgt_float_t *dst_ptr) {
+ if constexpr (std::is_same::value) {
+ src.store(dst_ptr);
+ } else {
+ vec_t tmp;
+ tmp.cast_from(src);
+ tmp.store(dst_ptr);
+ }
+}
+
+#ifdef FLASHINFER_USE_FP8
+/******************* vec_t<__nv_fp8_e4m3> *******************/
+
+// __nv_fp8_e4m3 x 1
+template <>
+struct vec_t<__nv_fp8_e4m3, 1> {
+ __nv_fp8_e4m3 data;
+
+ FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
+ return ((__nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
+ return ((const __nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
+ const __nv_fp8_e4m3 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
+ data = val;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) {
+ data = *ptr;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(
+ __nv_fp8_e4m3 *ptr) const {
+ *ptr = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
+ __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
+ *dst = *src;
+}
+
+// __nv_fp8_e4m3 x 2
+template <>
+struct vec_t<__nv_fp8_e4m3, 2> {
+ __nv_fp8x2_e4m3 data;
+
+ FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
+ return ((__nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
+ return ((const __nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
+ const __nv_fp8_e4m3 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
+ data.__x =
+ (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) {
+ data = *((__nv_fp8x2_e4m3 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(
+ __nv_fp8_e4m3 *ptr) const {
+ *((__nv_fp8x2_e4m3 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
+ __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
+ *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src);
+}
+
+// __nv_fp8_e4m3 x 4
+
+template <>
+struct vec_t<__nv_fp8_e4m3, 4> {
+ __nv_fp8x4_e4m3 data;
+
+ FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
+ return ((__nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
+ return ((const __nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
+ const __nv_fp8_e4m3 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
+ data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) |
+ __nv_fp8x4_storage_t(val.__x);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) {
+ data = *((__nv_fp8x4_e4m3 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(
+ __nv_fp8_e4m3 *ptr) const {
+ *((__nv_fp8x4_e4m3 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
+ __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
+ *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src);
+}
+
+// __nv_fp8_e4m3 x 8
+
+template <>
+struct vec_t<__nv_fp8_e4m3, 8> {
+ uint2 data;
+
+ FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
+ return ((__nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
+ return ((const __nv_fp8_e4m3 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
+ const __nv_fp8_e4m3 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
+ ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) |
+ __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) |
+ __nv_fp8x4_storage_t(val.__x);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) {
+ data = *((uint2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(
+ __nv_fp8_e4m3 *ptr) const {
+ *((uint2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
+ __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
+ *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src);
+}
+
+// __nv_fp8_e4m3 x 16 or more
+template
+struct vec_t<__nv_fp8_e4m3, vec_size> {
+ uint4 data[vec_size / 16];
+
+ FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
+ return ((__nv_fp8_e4m3 *)data)[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
+ return ((const __nv_fp8_e4m3 *)data)[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ }
+ }
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ data[i] = ((uint4 *)ptr)[i];
+ }
+ }
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ ((uint4 *)ptr)[i] = data[i];
+ }
+ }
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
+ const __nv_fp8_e4m3 *src) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ ((uint4 *)dst)[i] = ((uint4 *)src)[i];
+ }
+ }
+};
+
+/******************* vec_t<__nv_fp8_e5m2> *******************/
+
+// __nv_fp8_e5m2 x 1
+template <>
+struct vec_t<__nv_fp8_e5m2, 1> {
+ __nv_fp8_e5m2 data;
+
+ FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
+ return ((__nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
+ return ((const __nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
+ const __nv_fp8_e5m2 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
+ data = val;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) {
+ data = *ptr;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(
+ __nv_fp8_e5m2 *ptr) const {
+ *ptr = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
+ __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
+ *dst = *src;
+}
+
+// __nv_fp8_e5m2 x 2
+template <>
+struct vec_t<__nv_fp8_e5m2, 2> {
+ __nv_fp8x2_e5m2 data;
+
+ FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
+ return ((__nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
+ return ((const __nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
+ const __nv_fp8_e5m2 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
+ data.__x =
+ (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) {
+ data = *((__nv_fp8x2_e5m2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(
+ __nv_fp8_e5m2 *ptr) const {
+ *((__nv_fp8x2_e5m2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
+ __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
+ *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src);
+}
+
+// __nv_fp8_e5m2 x 4
+
+template <>
+struct vec_t<__nv_fp8_e5m2, 4> {
+ __nv_fp8x4_e5m2 data;
+
+ FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
+ return ((__nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
+ return ((const __nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
+ const __nv_fp8_e5m2 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
+ data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) |
+ __nv_fp8x4_storage_t(val.__x);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) {
+ data = *((__nv_fp8x4_e5m2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(
+ __nv_fp8_e5m2 *ptr) const {
+ *((__nv_fp8x4_e5m2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
+ __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
+ *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src);
+}
+
+// __nv_fp8_e5m2 x 8
+
+template <>
+struct vec_t<__nv_fp8_e5m2, 8> {
+ uint2 data;
+
+ FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
+ return ((__nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
+ return ((const __nv_fp8_e5m2 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
+ const __nv_fp8_e5m2 *src);
+};
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
+ ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) |
+ __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) |
+ __nv_fp8x4_storage_t(val.__x);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) {
+ data = *((uint2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(
+ __nv_fp8_e5m2 *ptr) const {
+ *((uint2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
+ __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
+ *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src);
+}
+
+// __nv_fp8_e5m2 x 16 or more
+
+template
+struct vec_t<__nv_fp8_e5m2, vec_size> {
+ uint4 data[vec_size / 16];
+
+ FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
+ return ((__nv_fp8_e5m2 *)data)[i];
+ }
+ FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
+ return ((const __nv_fp8_e5m2 *)data)[i];
+ }
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x =
+ (__nv_fp8x4_storage_t(val.__x) << 24) |
+ (__nv_fp8x4_storage_t(val.__x) << 16) |
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+ }
+ }
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ data[i] = ((uint4 *)ptr)[i];
+ }
+ }
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ ((uint4 *)ptr)[i] = data[i];
+ }
+ }
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
+ const __nv_fp8_e5m2 *src) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 16; ++i) {
+ ((uint4 *)dst)[i] = ((uint4 *)src)[i];
+ }
+ }
+};
+#endif
+
+/******************* vec_t *******************/
+
+// half x 1
+template <>
+struct vec_t {
+ half data;
+
+ FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
+ FLASHINFER_INLINE const half &operator[](size_t i) const {
+ return ((const half *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(half val);
+ FLASHINFER_INLINE void load(const half *ptr);
+ FLASHINFER_INLINE void store(half *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(half val) { data = val; }
+
+FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; }
+
+FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; }
+
+FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) {
+ *dst = *src;
+}
+
+// half x 2
+template <>
+struct vec_t {
+ half2 data;
+
+ FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
+ FLASHINFER_INLINE const half &operator[](size_t i) const {
+ return ((const half *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(half val);
+ FLASHINFER_INLINE void load(const half *ptr);
+ FLASHINFER_INLINE void store(half *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(half val) {
+ data = make_half2(val, val);
+}
+
+FLASHINFER_INLINE void vec_t::load(const half *ptr) {
+ data = *((half2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t::store(half *ptr) const {
+ *((half2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) {
+ *((half2 *)dst) = *((half2 *)src);
+}
+
+// half x 4
+
+template <>
+struct vec_t {
+ uint2 data;
+
+ FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
+ FLASHINFER_INLINE const half &operator[](size_t i) const {
+ return ((const half *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(half val);
+ FLASHINFER_INLINE void load(const half *ptr);
+ FLASHINFER_INLINE void store(half *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(half val) {
+ *(half2 *)(&data.x) = make_half2(val, val);
+ *(half2 *)(&data.y) = make_half2(val, val);
+}
+
+FLASHINFER_INLINE void vec_t::load(const half *ptr) {
+ data = *((uint2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t::store(half *ptr) const {
+ *((uint2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) {
+ *((uint2 *)dst) = *((uint2 *)src);
+}
+
+// half x 8 or more
+
+template
+struct vec_t {
+ uint4 data[vec_size / 8];
+ FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; }
+ FLASHINFER_INLINE const half &operator[](size_t i) const {
+ return ((const half *)data)[i];
+ }
+ FLASHINFER_INLINE void fill(half val) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size; ++i) {
+ *(half2 *)(&(data[i].x)) = make_half2(val, val);
+ *(half2 *)(&(data[i].y)) = make_half2(val, val);
+ *(half2 *)(&(data[i].z)) = make_half2(val, val);
+ *(half2 *)(&(data[i].w)) = make_half2(val, val);
+ }
+ }
+ FLASHINFER_INLINE void load(const half *ptr) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 8; ++i) {
+ data[i] = ((uint4 *)ptr)[i];
+ }
+ }
+ FLASHINFER_INLINE void store(half *ptr) const {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 8; ++i) {
+ ((uint4 *)ptr)[i] = data[i];
+ }
+ }
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(half *dst, const half *src) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 8; ++i) {
+ ((uint4 *)dst)[i] = ((uint4 *)src)[i];
+ }
+ }
+};
+
+/******************* vec_t *******************/
+
+// nv_bfloat16 x 1
+template <>
+struct vec_t {
+ nv_bfloat16 data;
+
+ FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
+ return ((nv_bfloat16 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
+ return ((const nv_bfloat16 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(nv_bfloat16 val);
+ FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
+ FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
+ const nv_bfloat16 *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) {
+ data = val;
+}
+
+FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) {
+ data = *ptr;
+}
+
+FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const {
+ *ptr = data;
+}
+
+FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst,
+ const nv_bfloat16 *src) {
+ *dst = *src;
+}
+
+// nv_bfloat16 x 2
+template <>
+struct vec_t {
+ nv_bfloat162 data;
+
+ FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
+ return ((nv_bfloat16 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
+ return ((const nv_bfloat16 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(nv_bfloat16 val);
+ FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
+ FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
+ const nv_bfloat16 *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) {
+ data = make_bfloat162(val, val);
+}
+
+FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) {
+ data = *((nv_bfloat162 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const {
+ *((nv_bfloat162 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst,
+ const nv_bfloat16 *src) {
+ *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src);
+}
+
+// nv_bfloat16 x 4
+
+template <>
+struct vec_t {
+ uint2 data;
+
+ FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
+ return ((nv_bfloat16 *)(&data))[i];
+ }
+ FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
+ return ((const nv_bfloat16 *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(nv_bfloat16 val);
+ FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
+ FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
+ const nv_bfloat16 *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) {
+ *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val);
+ *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val);
+}
+
+FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) {
+ data = *((uint2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const {
+ *((uint2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst,
+ const nv_bfloat16 *src) {
+ *((uint2 *)dst) = *((uint2 *)src);
+}
+
+// nv_bfloat16 x 8 or more
+
+template
+struct vec_t {
+ uint4 data[vec_size / 8];
+
+ FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
+ return ((nv_bfloat16 *)data)[i];
+ }
+ FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
+ return ((const nv_bfloat16 *)data)[i];
+ }
+ FLASHINFER_INLINE void fill(nv_bfloat16 val) {
+#pragma unoll
+ for (size_t i = 0; i < vec_size / 8; ++i) {
+ *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val);
+ *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val);
+ *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val);
+ *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val);
+ }
+ }
+ FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) {
+#pragma unoll
+ for (size_t i = 0; i < vec_size / 8; ++i) {
+ data[i] = ((uint4 *)ptr)[i];
+ }
+ }
+ FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const {
+#pragma unoll
+ for (size_t i = 0; i < vec_size / 8; ++i) {
+ ((uint4 *)ptr)[i] = data[i];
+ }
+ }
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
+ const nv_bfloat16 *src) {
+#pragma unoll
+ for (size_t i = 0; i < vec_size / 8; ++i) {
+ ((uint4 *)dst)[i] = ((uint4 *)src)[i];
+ }
+ }
+};
+
+/******************* vec_t *******************/
+
+// float x 1
+
+template <>
+struct vec_t {
+ float data;
+
+ FLASHINFER_INLINE float &operator[](size_t i) {
+ return ((float *)(&data))[i];
+ }
+ FLASHINFER_INLINE const float &operator[](size_t i) const {
+ return ((const float *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(float val);
+ FLASHINFER_INLINE void load(const float *ptr);
+ FLASHINFER_INLINE void store(float *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+
+ FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(float val) { data = val; }
+
+FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; }
+
+FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; }
+
+FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) {
+ *dst = *src;
+}
+
+// float x 2
+
+template <>
+struct vec_t {
+ float2 data;
+
+ FLASHINFER_INLINE float &operator[](size_t i) {
+ return ((float *)(&data))[i];
+ }
+ FLASHINFER_INLINE const float &operator[](size_t i) const {
+ return ((const float *)(&data))[i];
+ }
+ FLASHINFER_INLINE void fill(float val);
+ FLASHINFER_INLINE void load(const float *ptr);
+ FLASHINFER_INLINE void store(float *ptr) const;
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+ FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
+};
+
+FLASHINFER_INLINE void vec_t::fill(float val) {
+ data = make_float2(val, val);
+}
+
+FLASHINFER_INLINE void vec_t::load(const float *ptr) {
+ data = *((float2 *)ptr);
+}
+
+FLASHINFER_INLINE void vec_t::store(float *ptr) const {
+ *((float2 *)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) {
+ *((float2 *)dst) = *((float2 *)src);
+}
+
+// float x 4 or more
+template
+struct vec_t {
+ float4 data[vec_size / 4];
+
+ FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
+ FLASHINFER_INLINE const float &operator[](size_t i) const {
+ return ((const float *)(data))[i];
+ }
+ FLASHINFER_INLINE void fill(float val) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ data[i] = make_float4(val, val, val, val);
+ }
+ }
+ FLASHINFER_INLINE void load(const float *ptr) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ data[i] = ((float4 *)ptr)[i];
+ }
+ }
+ FLASHINFER_INLINE void store(float *ptr) const {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ ((float4 *)ptr)[i] = data[i];
+ }
+ }
+ template
+ FLASHINFER_INLINE void cast_from(const vec_t &src) {
+ cast_from_impl(src, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_load(const T *ptr) {
+ cast_load_impl(ptr, *this);
+ }
+ template
+ FLASHINFER_INLINE void cast_store(T *ptr) const {
+ cast_store_impl(*this, ptr);
+ }
+ FLASHINFER_INLINE static void memcpy(float *dst, const float *src) {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ ((float4 *)dst)[i] = ((float4 *)src)[i];
+ }
+ }
+};
+
+/******************* vec_t type cast *******************/
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = float(src.data);
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 2; ++i) {
+ ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = half(src.data);
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 2; ++i) {
+ ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = float(src.data);
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 2; ++i) {
+ ((float2 *)(&dst.data))[i] =
+ __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = nv_bfloat16(src.data);
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 2; ++i) {
+ ((nv_bfloat162 *)(&dst.data))[i] =
+ __float22bfloat162_rn(((float2 *)(&src.data))[i]);
+ }
+ }
+}
+
+#ifdef FLASHINFER_USE_FP8
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = float(src.data);
+ } else if constexpr (vec_size == 2) {
+ *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data));
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = float(src.data);
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 2; ++i) {
+ ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t<__nv_fp8_e4m3, vec_size> &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = __nv_fp8_e4m3(src.data);
+ } else if constexpr (vec_size == 2) {
+ *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data));
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ ((__nv_fp8x4_e4m3 *)(&dst.data))[i] =
+ __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t<__nv_fp8_e4m3, vec_size> &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = __nv_fp8_e4m3(src.data);
+ } else if constexpr (vec_size == 2) {
+ *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data));
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ // NOTE(Zihao): need to double check if we properly handle flo and fhi
+ ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3(
+ ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = float(src.data);
+ } else if constexpr (vec_size == 2) {
+ *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data));
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
+ vec_t &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = float(src.data);
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 2; ++i) {
+ ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t<__nv_fp8_e5m2, vec_size> &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = __nv_fp8_e5m2(src.data);
+ } else if constexpr (vec_size == 2) {
+ *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data));
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ ((__nv_fp8x4_e5m2 *)(&dst.data))[i] =
+ __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]);
+ }
+ }
+}
+
+template
+FLASHINFER_INLINE void cast_from_impl(const vec_t &src,
+ vec_t<__nv_fp8_e5m2, vec_size> &dst) {
+ if constexpr (vec_size == 1) {
+ dst.data = __nv_fp8_e4m3(src.data);
+ } else if constexpr (vec_size == 2) {
+ *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data));
+ } else {
+#pragma unroll
+ for (size_t i = 0; i < vec_size / 4; ++i) {
+ // NOTE(Zihao): need to double check if we properly handle flo and fhi
+ ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2(
+ ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
+ }
+ }
+}
+
+#endif // FLASHINFER_USE_FP8
+
+#endif // VEC_DTYPES_CUH_
diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu
new file mode 100644
index 0000000000000..644740d9c49b0
--- /dev/null
+++ b/csrc/punica/punica_ops.cu
@@ -0,0 +1,550 @@
+#include
+#include
+
+#include
+
+#include "type_convert.h"
+#include "../cuda_compat.h"
+#include "bgmv/bgmv_config.h"
+
+//====== utils ======
+
+inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
+ const char *a_name, const char *b_name) {
+ TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
+ a.dim(), " vs ", b.dim());
+ for (int i = 0; i < a.dim(); ++i) {
+ TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
+ ".size(", i, ")");
+ }
+}
+
+inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
+ return (uint32_t(a) << 16) | uint32_t(b);
+}
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+
+#define CHECK_CONTIGUOUS(x) \
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+#define CHECK_INPUT(x) \
+ CHECK_CUDA(x); \
+ CHECK_CONTIGUOUS(x)
+
+#define CHECK_DIM(d, x) \
+ TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
+
+#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
+
+#define CHECK_EQ(a, b) \
+ TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+
+//====== bgmv ======
+
+template
+inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
+ const int64_t *lora_indices,
+ uint16_t in_features, uint16_t out_features,
+ int64_t y_offset, int64_t full_y_size,
+ int64_t batch_size, int64_t num_layers,
+ int64_t layer_idx, float scale) {
+ switch (pack_u16(in_features, out_features)) {
+#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
+ case pack_u16(feat_in, feat_out): \
+ bgmv_kernel(Y, X, W, lora_indices, y_offset, \
+ full_y_size, batch_size, num_layers, \
+ layer_idx, scale); \
+ break;
+#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
+ CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
+ CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
+
+ FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
+#undef CASE
+#undef CASE_ONESIDE
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
+ torch::Tensor indicies, int64_t layer_idx, float scale) {
+ CHECK_INPUT(y);
+ CHECK_INPUT(x);
+ CHECK_INPUT(w);
+ CHECK_INPUT(indicies);
+
+ CHECK_DIM(2, y);
+ CHECK_DIM(2, x);
+ CHECK_DIM(4, w);
+ CHECK_DIM(1, indicies);
+
+ int64_t B = x.size(0);
+ int64_t h_in = x.size(1);
+ int64_t h_out = y.size(1);
+ int64_t num_layers = w.size(1);
+ CHECK_EQ(w.size(3), h_in);
+ CHECK_EQ(w.size(2), h_out);
+ CHECK_EQ(indicies.size(0), x.size(0));
+ CHECK_EQ(y.size(0), x.size(0));
+ bool ok = false;
+ if (h_in < 65536 && h_out < 65536) {
+ // TODO: See if we can get rid of this massive nested switch
+ switch (x.scalar_type()) {
+ case at::ScalarType::Half:
+ switch (y.scalar_type()) {
+ case at::ScalarType::Half:
+ switch (w.scalar_type()) {
+ case at::ScalarType::Half:
+ ok = launch_bgmv_kernel(static_cast(y.data_ptr()),
+ static_cast(x.data_ptr()),
+ static_cast(w.data_ptr()),
+ indicies.data_ptr(), h_in, h_out, 0,
+ h_out, B, num_layers, layer_idx, scale);
+ break;
+ case at::ScalarType::BFloat16:
+ ok = launch_bgmv_kernel(static_cast(y.data_ptr()),
+ static_cast(x.data_ptr()),
+ static_cast(w.data_ptr()),
+ indicies.data_ptr(), h_in, h_out, 0,
+ h_out, B, num_layers, layer_idx, scale);
+ break;
+ default:
+ break;
+ }
+ break;
+ case at::ScalarType::BFloat16:
+ switch (w.scalar_type()) {
+ case at::ScalarType::Half:
+ ok = launch_bgmv_kernel(static_cast(y.data_ptr()),
+ static_cast(x.data_ptr()),
+ static_cast(w.data_ptr()),
+ indicies.data_ptr(), h_in, h_out, 0,
+ h_out, B, num_layers, layer_idx, scale);
+ break;
+ case at::ScalarType::BFloat16:
+ ok = launch_bgmv_kernel(static_cast