
Commit

perf: reduce the read and write of shared memory in the FusedAddRMSNormKernel (#592)

Use `vec_t<float, VEC_SIZE> x_vec` to reduce the number of read and
write operations to shared memory.
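As an illustration of the idea only (not the flashinfer kernel itself), the sketch below uses plain CUDA with `float4` standing in for `vec_t<float, VEC_SIZE>` at VEC_SIZE = 4: the first toy kernel round-trips each float through shared memory one element at a time, while the second stages the values in a register vector and touches shared memory with a single vectorized store and load per thread per round, which is the access pattern the `x_vec` change introduces.

```cuda
#include <cuda_runtime.h>

// Before: each float is written to and later read back from shared memory
// individually -- 4 stores + 4 loads per thread per round.
__global__ void scalar_smem_roundtrip(const float* __restrict__ in, float* __restrict__ out) {
  extern __shared__ float smem[];
  const int base = threadIdx.x * 4;
  const int gbase = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  for (int j = 0; j < 4; ++j) smem[base + j] = in[gbase + j];
  __syncthreads();
  float acc = 0.f;
  for (int j = 0; j < 4; ++j) acc += smem[base + j];
  out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}

// After: values are staged in a register vector (float4 here, standing in for
// vec_t<float, VEC_SIZE>) and hit shared memory with one vectorized 128-bit
// store and one vectorized 128-bit load per thread per round.
__global__ void vectorized_smem_roundtrip(const float* __restrict__ in, float* __restrict__ out) {
  extern __shared__ float smem[];
  const int base = threadIdx.x * 4;
  const int gbase = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  float4 x_vec = *reinterpret_cast<const float4*>(in + gbase);    // values stay in registers
  *reinterpret_cast<float4*>(smem + base) = x_vec;                // single vectorized write
  __syncthreads();
  const float4 y = *reinterpret_cast<const float4*>(smem + base); // single vectorized read
  out[blockIdx.x * blockDim.x + threadIdx.x] = y.x + y.y + y.z + y.w;
}
```

In the real kernel the intermediate `x` values are still passed between the two passes through shared memory; the change is that the per-element writes in the first pass and per-element reads in the second pass become one `x_vec.store`/`x_vec.load` each, which typically compiles to wide shared-memory instructions.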
Abatom authored Nov 9, 2024
1 parent 1058d1e commit 2043ca2
Showing 2 changed files with 63 additions and 3 deletions.
55 changes: 55 additions & 0 deletions benchmarks/bench_fused_add_rmsnorm.py
@@ -0,0 +1,55 @@
import argparse
from typing import cast

import torch
from triton.testing import do_bench

import flashinfer

@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-sizes", nargs='+', type=int, default=[1, 19, 99, 989])
    parser.add_argument("--hidden-sizes", nargs='+', type=int, default=[111, 500, 1024, 3072, 4096, 8192])
    parser.add_argument("--dtypes", nargs='+', choices=["float16", "bfloat16"], default=["float16"])
    args = parser.parse_args()

    eps = 1e-6

    # Loop over each combination of batch_size, hidden_size, and dtype
    for batch_size in args.batch_sizes:
        for hidden_size in args.hidden_sizes:
            for dtype_str in args.dtypes:
                dtype = getattr(torch, dtype_str)

                # Define tensors with the correct dtype
                x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
                residual = torch.randn_like(x)
                weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

                @torch.cuda.nvtx.range(f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}")
                def fn() -> None:
                    flashinfer.fused_add_rmsnorm(x, residual, weight, eps)

                # Run benchmarking
                latency_ms = cast(float, do_bench(fn))
                throughput = (
                    (x.numel() * x.element_size() * 2
                     + residual.numel() * residual.element_size() * 2
                     + weight.numel() * weight.element_size())
                    / (latency_ms * 1e-3)
                )
                print(
                    f"batch_size: {batch_size:3},",
                    f"hidden_size: {hidden_size:5},",
                    f"dtype: {dtype_str:8},",
                    f"latency: {latency_ms*1e3:2.0f}us,",
                    f"throughput: {throughput*1e-9:7.3f}GB/s",
                )

        print("---")

    torch.cuda.profiler.stop()

if __name__ == "__main__":
    main()
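The benchmark above can be run directly, e.g. `python benchmarks/bench_fused_add_rmsnorm.py` (assuming a CUDA-capable GPU with flashinfer and triton installed); it sweeps the batch and hidden sizes above for each dtype and reports latency and effective bandwidth of `fused_add_rmsnorm`.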
11 changes: 8 additions & 3 deletions include/flashinfer/norm.cuh
@@ -133,6 +133,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     input_vec.fill(0.f);
     vec_t<T, VEC_SIZE> residual_vec;
     residual_vec.fill(0.f);
+    vec_t<float, VEC_SIZE> x_vec;
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -143,10 +145,11 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
       x += float(residual_vec[j]);
       sum_sq += x * x;
       residual_vec[j] = (T)x;
-      smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
+      x_vec[j] = x;
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }

@@ -174,15 +177,17 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<T, VEC_SIZE> input_vec;
     vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<float, VEC_SIZE> x_vec;
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
-      input_vec[j] = x * rms_rcp * float(weight_vec[j]);
+      input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);

