Skip to content

Commit

Permalink
[RMSNorm][Kernel] Add FLOAT2/HALF2_VARIANCE macro (#29)
Browse files Browse the repository at this point in the history
* Update elementwise.cu

* Update layer_norm.cu

* Update relu.cu

* Update rms_norm.cu

* Update rms_norm.cu

* Update rms_norm.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update rms_norm.cu

* Update rms_norm.py

* Update README.md

* Update rms_norm.cu
  • Loading branch information
DefTruth authored Sep 14, 2024
1 parent cea328b commit 5601f72
Showing 1 changed file with 37 additions and 73 deletions.
110 changes: 37 additions & 73 deletions rms-norm/rms_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ __global__ void rms_norm_f16x2_f16_kernel(half* x, half* y, float g, int N, int
if (idx < N * K) HALF2(y[idx]) = reg_y;
}

// HALF2_VARIANCE(reg, i): sum of squares of the two lanes of `reg`
// (reg.x^2 + reg.y^2), evaluated with the operand type's own arithmetic, or a
// half-precision zero when element offset `idx + i` lies at or beyond `N * K`.
// Not hygienic: `idx`, `N` and `K` are taken from the scope where the macro is
// expanded, so it can only be used where those names are defined.
#define HALF2_VARIANCE(reg, i) \
  (((idx + (i)) >= N * K) ? __float2half(0.0f) : ((reg).x * (reg).x + (reg).y * (reg).y))

// FLOAT2_VARIANCE(reg, i): sum of squares of the two lanes of `reg`
// (reg.x^2 + reg.y^2), or 0.0f when element offset `idx + i` lies at or beyond
// `N * K`.
// Not hygienic: `idx`, `N` and `K` are taken from the scope where the macro is
// expanded, so it can only be used where those names are defined.
#define FLOAT2_VARIANCE(reg, i) \
  (((idx + (i)) >= N * K) ? 0.0f : ((reg).x * (reg).x + (reg).y * (reg).y))

template<const int NUM_THREADS=256>
__global__ void rms_norm_f16x8_f16_kernel(half* x, half* y, float g, int N, int K) {
int tid = threadIdx.x; // 0..K-1
Expand All @@ -209,18 +215,10 @@ __global__ void rms_norm_f16x8_f16_kernel(half* x, half* y, float g, int N, int
half2 reg_x_1 = HALF2(x[idx + 2]);
half2 reg_x_2 = HALF2(x[idx + 4]);
half2 reg_x_3 = HALF2(x[idx + 6]);
half variance = (((idx + 0) < N * K) ? (reg_x_0.x * reg_x_0.x
+ reg_x_0.y * reg_x_0.y)
: __float2half(0.0f));
variance += (((idx + 2) < N * K) ? (reg_x_1.x * reg_x_1.x
+ reg_x_1.y * reg_x_1.y)
: __float2half(0.0f));
variance += (((idx + 4) < N * K) ? (reg_x_2.x * reg_x_2.x
+ reg_x_2.y * reg_x_2.y)
: __float2half(0.0f));
variance += (((idx + 6) < N * K) ? (reg_x_3.x * reg_x_3.x
+ reg_x_3.y * reg_x_3.y)
: __float2half(0.0f));
half variance = HALF2_VARIANCE(reg_x_0, 0);
variance += HALF2_VARIANCE(reg_x_1, 2);
variance += HALF2_VARIANCE(reg_x_2, 4);
variance += HALF2_VARIANCE(reg_x_3, 6);
variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
if (tid == 0) s_variance = hrsqrt(variance / (K_ + epsilon));
// wait for s_variance in shared memory to be ready for all threads
Expand Down Expand Up @@ -259,18 +257,12 @@ __global__ void rms_norm_f16x8_f32_kernel(half* x, half* y, float g, int N, int
float2 reg_x_1 = __half22float2(HALF2(x[idx + 2]));
float2 reg_x_2 = __half22float2(HALF2(x[idx + 4]));
float2 reg_x_3 = __half22float2(HALF2(x[idx + 6]));
float variance = (((idx + 0) < N * K) ? (reg_x_0.x * reg_x_0.x
+ reg_x_0.y * reg_x_0.y)
: 0.0f);
variance += (((idx + 2) < N * K) ? (reg_x_1.x * reg_x_1.x
+ reg_x_1.y * reg_x_1.y)
: 0.0f);
variance += (((idx + 4) < N * K) ? (reg_x_2.x * reg_x_2.x
+ reg_x_2.y * reg_x_2.y)
: 0.0f);
variance += (((idx + 6) < N * K) ? (reg_x_3.x * reg_x_3.x
+ reg_x_3.y * reg_x_3.y)
: 0.0f);

float variance = FLOAT2_VARIANCE(reg_x_0, 0);
variance += FLOAT2_VARIANCE(reg_x_1, 2);
variance += FLOAT2_VARIANCE(reg_x_2, 4);
variance += FLOAT2_VARIANCE(reg_x_3, 6);

variance = block_reduce_sum_f32<NUM_THREADS>(variance);
if (tid == 0) s_variance = rsqrtf(variance / ((float) K + epsilon));
// wait for s_variance in shared memory to be ready for all threads
Expand Down Expand Up @@ -315,30 +307,16 @@ __global__ void rms_norm_f16x16_f16_kernel(half* x, half* y, float g, int N, int
half2 reg_x_5 = HALF2(x[idx + 10]);
half2 reg_x_6 = HALF2(x[idx + 12]);
half2 reg_x_7 = HALF2(x[idx + 14]);
half variance = (((idx + 0) < N * K) ? (reg_x_0.x * reg_x_0.x
+ reg_x_0.y * reg_x_0.y)
: __float2half(0.0f));
variance += (((idx + 2) < N * K) ? (reg_x_1.x * reg_x_1.x
+ reg_x_1.y * reg_x_1.y)
: __float2half(0.0f));
variance += (((idx + 4) < N * K) ? (reg_x_2.x * reg_x_2.x
+ reg_x_2.y * reg_x_2.y)
: __float2half(0.0f));
variance += (((idx + 6) < N * K) ? (reg_x_3.x * reg_x_3.x
+ reg_x_3.y * reg_x_3.y)
: __float2half(0.0f));
variance += (((idx + 8) < N * K) ? (reg_x_4.x * reg_x_4.x
+ reg_x_4.y * reg_x_4.y)
: __float2half(0.0f));
variance += (((idx + 10) < N * K) ? (reg_x_5.x * reg_x_5.x
+ reg_x_5.y * reg_x_5.y)
: __float2half(0.0f));
variance += (((idx + 12) < N * K) ? (reg_x_6.x * reg_x_6.x
+ reg_x_6.y * reg_x_6.y)
: __float2half(0.0f));
variance += (((idx + 14) < N * K) ? (reg_x_7.x * reg_x_7.x
+ reg_x_7.y * reg_x_7.y)
: __float2half(0.0f));

half variance = HALF2_VARIANCE(reg_x_0, 0);
variance += HALF2_VARIANCE(reg_x_1, 2);
variance += HALF2_VARIANCE(reg_x_2, 4);
variance += HALF2_VARIANCE(reg_x_3, 6);
variance += HALF2_VARIANCE(reg_x_4, 8);
variance += HALF2_VARIANCE(reg_x_5, 10);
variance += HALF2_VARIANCE(reg_x_6, 12);
variance += HALF2_VARIANCE(reg_x_7, 14);

variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
if (tid == 0) s_variance = hrsqrt(variance / (K_ + epsilon));
// wait for s_variance in shared memory to be ready for all threads
Expand Down Expand Up @@ -394,30 +372,16 @@ __global__ void rms_norm_f16x16_f32_kernel(half* x, half* y, float g, int N, int
float2 reg_x_5 = __half22float2(HALF2(x[idx + 10]));
float2 reg_x_6 = __half22float2(HALF2(x[idx + 12]));
float2 reg_x_7 = __half22float2(HALF2(x[idx + 14]));
float variance = (((idx + 0) < N * K) ? (reg_x_0.x * reg_x_0.x
+ reg_x_0.y * reg_x_0.y)
: 0.0f);
variance += (((idx + 2) < N * K) ? (reg_x_1.x * reg_x_1.x
+ reg_x_1.y * reg_x_1.y)
: 0.0f);
variance += (((idx + 4) < N * K) ? (reg_x_2.x * reg_x_2.x
+ reg_x_2.y * reg_x_2.y)
: 0.0f);
variance += (((idx + 6) < N * K) ? (reg_x_3.x * reg_x_3.x
+ reg_x_3.y * reg_x_3.y)
: 0.0f);
variance += (((idx + 8) < N * K) ? (reg_x_4.x * reg_x_4.x
+ reg_x_4.y * reg_x_4.y)
: 0.0f);
variance += (((idx + 10) < N * K) ? (reg_x_5.x * reg_x_5.x
+ reg_x_5.y * reg_x_5.y)
: 0.0f);
variance += (((idx + 12) < N * K) ? (reg_x_6.x * reg_x_6.x
+ reg_x_6.y * reg_x_6.y)
: 0.0f);
variance += (((idx + 14) < N * K) ? (reg_x_7.x * reg_x_7.x
+ reg_x_7.y * reg_x_7.y)
: 0.0f);

float variance = FLOAT2_VARIANCE(reg_x_0, 0);
variance += FLOAT2_VARIANCE(reg_x_1, 2);
variance += FLOAT2_VARIANCE(reg_x_2, 4);
variance += FLOAT2_VARIANCE(reg_x_3, 6);
variance += FLOAT2_VARIANCE(reg_x_4, 8);
variance += FLOAT2_VARIANCE(reg_x_5, 10);
variance += FLOAT2_VARIANCE(reg_x_6, 12);
variance += FLOAT2_VARIANCE(reg_x_7, 14);

variance = block_reduce_sum_f32<NUM_THREADS>(variance);
if (tid == 0) s_variance = rsqrtf(variance / ((float) K + epsilon));
// wait for s_variance in shared memory to be ready for all threads
Expand Down Expand Up @@ -602,7 +566,7 @@ rms_norm_f16_f16_kernel<(K)><<<grid, block>>>( \
}

#define LANUCH_RMS_NORM_F16F32_KERNEL(K) \
rms_norm_f16_f32_kernel<(K)><<<grid, block>>>( \
rms_norm_f16_f32_kernel<(K)><<<grid, block>>>( \
reinterpret_cast<half*>(x.data_ptr()), \
reinterpret_cast<half*>(y.data_ptr()), \
g, N, (K));
Expand Down

0 comments on commit 5601f72

Please sign in to comment.