Commit cea328b

[RMSNorm] support f16x8_f32 RMSNorm (xlite-dev#28)

* Update elementwise.cu
* Update layer_norm.cu
* Update relu.cu
* Update rms_norm.cu
* Update rms_norm.py
* Update README.md

1 parent 5fe6ca6 commit cea328b
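The headline addition is `rms_norm_f16x8_f32_kernel`: f16 inputs and outputs, eight halves per thread, with the sum-of-squares reduction carried in f32. The committed code lives in `rms-norm/rms_norm.cu` and is not reproduced on this page; as orientation only, a kernel of this shape might look like the sketch below. The launch layout (one block per token row, `K == 8 * blockDim.x`, grid of `N` tokens), the `block_reduce_sum_f32` helper, and the `eps`/`g` handling are illustrative assumptions, not the commit's exact code.

```cuda
#include <cuda_fp16.h>

#define HALF2(value) (*reinterpret_cast<half2 *>(&(value)))

// Assumed helper: block-wide f32 sum via warp shuffles + shared memory.
// Every warp re-reduces the per-warp partials, so every thread returns
// the full block total (no separate broadcast needed).
template <const int NUM_THREADS = 256>
__device__ __forceinline__ float block_reduce_sum_f32(float val) {
  constexpr int NUM_WARPS = (NUM_THREADS + 31) / 32;
  __shared__ float smem[NUM_WARPS];
  const int warp = threadIdx.x / 32, lane = threadIdx.x % 32;
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, offset); // reduce within warp
  if (lane == 0) smem[warp] = val;
  __syncthreads();
  val = (lane < NUM_WARPS) ? smem[lane] : 0.0f;      // each warp re-reduces
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, offset);
  return val;
}

// Sketch: f16 in/out, f32 accumulation, 8 halves (16 bytes) per thread.
// Assumes grid(N tokens), block(K / 8), so every load/store is in bounds.
// y = x * rsqrt(mean(x^2) + eps) * g, per token row of width K.
template <const int NUM_THREADS = 256>
__global__ void rms_norm_f16x8_f32_kernel(half *x, half *y, float g,
                                          int N, int K) {
  const float eps = 1e-5f; // illustrative epsilon
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;

  // four vectorized half2 loads cover this thread's 8 elements
  half2 reg_x[4];
#pragma unroll
  for (int i = 0; i < 4; ++i) reg_x[i] = HALF2(x[idx + 2 * i]);

  // per-thread sum of squares, widened to f32 before accumulating
  float variance = 0.0f;
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    float2 v = __half22float2(reg_x[i]);
    variance += v.x * v.x + v.y * v.y;
  }
  variance = block_reduce_sum_f32<NUM_THREADS>(variance);
  float s_rms = rsqrtf(variance / (float)K + eps); // reciprocal RMS of the row

  // normalize in f32, round back to f16 on the way out
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    float2 v = __half22float2(reg_x[i]);
    HALF2(y[idx + 2 * i]) = __floats2half2_rn(v.x * s_rms * g, v.y * s_rms * g);
  }
}
```

Under this sketch's layout, a `K = 2048` row would launch as `rms_norm_f16x8_f32_kernel<256><<<N, 256>>>(x, y, g, N, 2048)`.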

File tree: 7 files changed, +378 −30 lines
README.md (3 additions, 0 deletions)

```diff
@@ -71,6 +71,9 @@
 | ✔️ [rms_norm_f16_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x2_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x8_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x16_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x16_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
```

elementwise/elementwise.cu (3 additions, 3 deletions)

```diff
@@ -67,10 +67,10 @@ __global__ void elementwise_add_f16x2_kernel(half* a, half* b, half* c, int N) {
 __global__ void elementwise_add_f16x8_kernel(half* a, half* b, half* c, int N) {
   int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
   // manual unroll and improve L2 cache hit rate.
-  // Only L2 cache: load 32 bytes(128 bits) in 1 memory issue (default)
-  // Enable L1 cache: load 128 bytes(512 bits) in 1 memory issue (-Xptxas -dlcm=ca)
+  // Only L2 cache: load 32 bytes in 1 memory issue (default)
+  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
   // why try fp16x8 within 1 threads? ref: https://zhuanlan.zhihu.com/p/641639133
-  // 0. first, tid_0 load 128 bits(32 byte) in 1 memory issue and cache data into L2 cache.
+  // 0. first, tid_0 load 32 bytes in 1 memory issue and cache data into L2 cache.
   // 1. then, tid_1,...,tid_3 hit L2 cache and load data from L2 cache directly.
   half2 reg_a_0 = HALF2(a[idx + 0]);
   half2 reg_a_1 = HALF2(a[idx + 2]);
```
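The comment being corrected here describes the f16x8 trick: each thread owns 8 contiguous halves through four `half2` registers, so the 32-byte L2 sector fetched for one thread also serves its neighbors. The diff shows only the first two loads; a self-contained sketch of how the full pattern plausibly continues (the `HALF2` macro is taken from the snippet, the bounds guards are modeled on it, not copied from the file):

```cuda
#include <cuda_fp16.h>

#define HALF2(value) (*reinterpret_cast<half2 *>(&(value)))

// each thread covers 8 contiguous halves via four half2 registers,
// so a 32-byte L2 sector fetched for tid_0 also serves tid_1..tid_3
__global__ void elementwise_add_f16x8_kernel(half *a, half *b, half *c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  half2 reg_a_0 = HALF2(a[idx + 0]), reg_b_0 = HALF2(b[idx + 0]);
  half2 reg_a_1 = HALF2(a[idx + 2]), reg_b_1 = HALF2(b[idx + 2]);
  half2 reg_a_2 = HALF2(a[idx + 4]), reg_b_2 = HALF2(b[idx + 4]);
  half2 reg_a_3 = HALF2(a[idx + 6]), reg_b_3 = HALF2(b[idx + 6]);
  // half2 adds keep the arithmetic vectorized as well
  if (idx + 1 < N) HALF2(c[idx + 0]) = __hadd2(reg_a_0, reg_b_0);
  if (idx + 3 < N) HALF2(c[idx + 2]) = __hadd2(reg_a_1, reg_b_1);
  if (idx + 5 < N) HALF2(c[idx + 4]) = __hadd2(reg_a_2, reg_b_2);
  if (idx + 7 < N) HALF2(c[idx + 6]) = __hadd2(reg_a_3, reg_b_3);
}
```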

layer-norm/layer_norm.cu (3 additions, 3 deletions)

```diff
@@ -238,10 +238,10 @@ __global__ void layer_norm_f16x8_f16_kernel(half* x, half* y, float g, float b,
   __shared__ half s_mean;     // shared within block
   __shared__ half s_variance; // shared within block
   // manual unroll and improve L2 cache hit rate.
-  // Only L2 cache: load 32 bytes(128 bits) in 1 memory issue (default)
-  // Enable L1 cache: load 128 bytes(512 bits) in 1 memory issue (-Xptxas -dlcm=ca)
+  // Only L2 cache: load 32 bytes in 1 memory issue (default)
+  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
   // why try fp16x8 within 1 threads? ref: https://zhuanlan.zhihu.com/p/641639133
-  // 0. first, tid_0 load 128 bits(32 byte) in 1 memory issue and cache data into L2 cache.
+  // 0. first, tid_0 load 32 bytes in 1 memory issue and cache data into L2 cache.
   // 1. then, tid_1,...,tid_3 hit L2 cache and load data from L2 cache directly.
   half2 reg_x_0 = HALF2(x[idx + 0]);
   half2 reg_x_1 = HALF2(x[idx + 2]);
```

relu/relu.cu (3 additions, 3 deletions)

```diff
@@ -57,10 +57,10 @@ __global__ void relu_f16x2_kernel(half* x, half* y, int N) {
 __global__ void relu_f16x8_kernel(half* x, half* y, int N) {
   int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
   // manual unroll and improve L2 cache hit rate.
-  // Only L2 cache: load 32 bytes(128 bits) in 1 memory issue (default)
-  // Enable L1 cache: load 128 bytes(512 bits) in 1 memory issue (-Xptxas -dlcm=ca)
+  // Only L2 cache: load 32 bytes in 1 memory issue (default)
+  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
   // why try fp16x8 within 1 threads? ref: https://zhuanlan.zhihu.com/p/641639133
-  // 0. first, tid_0 load 128 bits(32 byte) in 1 memory issue and cache data into L2 cache.
+  // 0. first, tid_0 load 32 bytes in 1 memory issue and cache data into L2 cache.
   // 1. then, tid_1,...,tid_3 hit L2 cache and load data from L2 cache directly.
   half2 reg_x_0 = HALF2(x[idx + 0]);
   half2 reg_x_1 = HALF2(x[idx + 2]);
```
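All three files' comments reference the same `ptxas` switch. For anyone reproducing the experiment, the flag is passed through `nvcc` roughly like this; the file name and `-arch` value are placeholders, not from the commit:

```bash
# -dlcm=ca asks ptxas to cache global loads in L1 as well ("cache all");
# per the comments above, the default leaves global loads in L2 only.
nvcc -O3 -arch=sm_80 -Xptxas -dlcm=ca relu.cu -o relu
```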

rms-norm/README.md (14 additions, 8 deletions)

````diff
@@ -9,6 +9,9 @@
 - [X] rms_norm_f16_f16_kernel
 - [X] rms_norm_f16x2_f16_kernel
 - [X] rms_norm_f16x8_f16_kernel
+- [X] rms_norm_f16x8_f32_kernel
+- [X] rms_norm_f16x16_f16_kernel
+- [X] rms_norm_f16x16_f32_kernel
 - [X] rms_norm_f16_f32_kernel
 - [X] PyTorch bindings
 
@@ -24,14 +27,17 @@ python3 rms_norm.py
 
 ```bash
 --------------------------------------------------------------------------------
-out_f32: [0.66361254, 0.69628561, 0.51101440], time:0.01188707ms
-out_f32x4: [0.66361260, 0.69628561, 0.51101440], time:0.00833464ms
-out_f32_th: [0.66361588, 0.69628906, 0.51101691], time:0.04334593ms
+out_f32: [0.92419142, -0.08846965, 1.06359947], time:0.03389192ms
+out_f32x4: [0.92419147, -0.08846966, 1.06359959], time:0.00855207ms
+out_f32_th: [0.92419606, -0.08847010, 1.06360483], time:0.04171062ms
 --------------------------------------------------------------------------------
-out_f16f16: [0.66357422, 0.69580078, 0.51074219], time:0.01201081ms
-out_f16x2f16: [0.66357422, 0.69580078, 0.51074219], time:0.00668955ms
-out_f16x8f16: [0.66650391, 0.69921875, 0.51318359], time:0.00398421ms
-out_f16f32: [0.66357422, 0.69628906, 0.51123047], time:0.01176858ms
-out_f16_th: [0.66357422, 0.69580078, 0.51074219], time:0.04448509ms
+out_f16f16: [0.92431641, -0.08843994, 1.06347656], time:0.03518176ms
+out_f16x2f16: [0.92431641, -0.08843994, 1.06347656], time:0.01200986ms
+out_f16x8f16: [0.92431641, -0.08843994, 1.06347656], time:0.00625682ms
+out_f16x8f32: [0.92431641, -0.08843994, 1.06347656], time:0.00625014ms
+out_f16x16f16: [0.92431641, -0.08843994, 1.06347656], time:0.02620339ms
+out_f16x16f32: [0.92431641, -0.08843994, 1.06347656], time:0.01505637ms
+out_f16f32: [0.92431641, -0.08843994, 1.06347656], time:0.03300810ms
+out_f16_th: [0.92431641, -0.08843994, 1.06347656], time:0.04187107ms
 --------------------------------------------------------------------------------
 ```
````
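For reference, every kernel in this list computes the standard per-token RMSNorm over a row of hidden width K; the f16xN_f32 variants differ only in carrying the sum-of-squares accumulation in f32. Assuming the usual scale g and a small ε:

$$ y_i = \frac{x_i}{\sqrt{\frac{1}{K}\sum_{j=1}^{K} x_j^2 + \varepsilon}} \cdot g $$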
