[RMSNorm] support f16x8_f32 RMSNorm (#28)
* Update elementwise.cu

* Update layer_norm.cu

* Update relu.cu

* Update rms_norm.cu

* Update rms_norm.cu

* Update rms_norm.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update rms_norm.cu

* Update rms_norm.py

* Update README.md
DefTruth authored Sep 14, 2024
1 parent 5fe6ca6 commit cea328b
Showing 7 changed files with 378 additions and 30 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -71,6 +71,9 @@
 | ✔️ [rms_norm_f16_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x2_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x8_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x16_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x16_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
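For context on what the new `rms_norm_f16x8_f32_kernel` variant does (f16 input/output, fp32 sum-of-squares accumulation, 8 halves per thread, one block per token), here is a minimal sketch. It is not the repo's exact code: the `block_reduce_sum_f32` helper, the `g` scale parameter, the `1e-5f` epsilon, and the assumption that K == NUM_THREADS * 8 are all inferred from the surrounding diffs.

```cuda
#include <cuda_fp16.h>

#define HALF2(value) (*reinterpret_cast<half2 *>(&(value)))

// Hypothetical block-level fp32 sum; the repo's own warp/block reduce
// helpers are assumed to exist under similar names.
template <int NUM_THREADS>
__device__ float block_reduce_sum_f32(float val) {
  __shared__ float shared[32];
  int lane = threadIdx.x % 32, warp = threadIdx.x / 32;
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, offset);
  if (lane == 0) shared[warp] = val;
  __syncthreads();
  val = (threadIdx.x < NUM_THREADS / 32) ? shared[threadIdx.x] : 0.0f;
  if (warp == 0)
    for (int offset = 16; offset > 0; offset >>= 1)
      val += __shfl_xor_sync(0xffffffff, val, offset);
  return val;  // valid in warp 0; only tid 0 consumes it below
}

// Per-token RMSNorm: y = x * rsqrt(mean(x^2) + eps) * g.
// Assumes one block per token row and K == NUM_THREADS * 8.
template <int NUM_THREADS = 256>
__global__ void rms_norm_f16x8_f32_kernel(half *x, half *y, float g, int K) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * blockDim.x + tid) * 8;
  __shared__ float s_inv_rms;

  half2 reg[4];
  float sum = 0.0f;
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    reg[i] = HALF2(x[idx + 2 * i]);
    float v0 = __half2float(reg[i].x), v1 = __half2float(reg[i].y);
    sum += v0 * v0 + v1 * v1;  // accumulate squares in fp32, not fp16
  }
  sum = block_reduce_sum_f32<NUM_THREADS>(sum);
  if (tid == 0) s_inv_rms = rsqrtf(sum / (float)K + 1e-5f);
  __syncthreads();

  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    float v0 = __half2float(reg[i].x) * s_inv_rms * g;
    float v1 = __half2float(reg[i].y) * s_inv_rms * g;
    HALF2(y[idx + 2 * i]) = __floats2half2_rn(v0, v1);  // round back to f16
  }
}
```

The point of the `_f32` suffix is that only the reduction runs in fp32; loads and stores stay packed f16, so the kernel keeps the f16x8 memory behavior while avoiding fp16 accumulation error in the sum of squares.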
6 changes: 3 additions & 3 deletions elementwise/elementwise.cu
@@ -67,10 +67,10 @@ __global__ void elementwise_add_f16x2_kernel(half* a, half* b, half* c, int N) {
 __global__ void elementwise_add_f16x8_kernel(half* a, half* b, half* c, int N) {
   int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
   // manual unroll and improve L2 cache hit rate.
-  // Only L2 cache: load 32 bytes(128 bits) in 1 memory issue (default)
-  // Enable L1 cache: load 128 bytes(512 bits) in 1 memory issue (-Xptxas -dlcm=ca)
+  // Only L2 cache: load 32 bytes in 1 memory issue (default)
+  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
   // why try fp16x8 within 1 threads? ref: https://zhuanlan.zhihu.com/p/641639133
-  // 0. first, tid_0 load 128 bits(32 byte) in 1 memory issue and cache data into L2 cache.
+  // 0. first, tid_0 load 32 bytes in 1 memory issue and cache data into L2 cache.
   // 1. then, tid_1,...,tid_3 hit L2 cache and load data from L2 cache directly.
   half2 reg_a_0 = HALF2(a[idx + 0]);
   half2 reg_a_1 = HALF2(a[idx + 2]);
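The hunk above shows only the first loads; a hedged sketch of how the full f16x8 elementwise add might continue (the bounds check and loop form are my own, condensed from the repo's manually unrolled style):

```cuda
#include <cuda_fp16.h>

#define HALF2(value) (*reinterpret_cast<half2 *>(&(value)))

// One thread handles 8 halves as 4 half2 packs. Each half2 access is
// 32-bit, and one thread's 8 halves span a contiguous 16-byte chunk,
// so after the first pack misses to L2 (or L1 with -Xptxas -dlcm=ca),
// the remaining packs of nearby threads hit the cached sector.
__global__ void elementwise_add_f16x8_kernel(half *a, half *b, half *c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    if (idx + i + 1 < N) {
      half2 reg_a = HALF2(a[idx + i]);
      half2 reg_b = HALF2(b[idx + i]);
      HALF2(c[idx + i]) = __hadd2(reg_a, reg_b);  // packed f16 add
    }
  }
}
```

The same load pattern, with the comment fix applied in this commit, recurs in `layer_norm.cu` and `relu.cu` below.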
6 changes: 3 additions & 3 deletions layer-norm/layer_norm.cu
@@ -238,10 +238,10 @@ __global__ void layer_norm_f16x8_f16_kernel(half* x, half* y, float g, float b,
   __shared__ half s_mean; // shared within block
   __shared__ half s_variance; // shared within block
   // manual unroll and improve L2 cache hit rate.
-  // Only L2 cache: load 32 bytes(128 bits) in 1 memory issue (default)
-  // Enable L1 cache: load 128 bytes(512 bits) in 1 memory issue (-Xptxas -dlcm=ca)
+  // Only L2 cache: load 32 bytes in 1 memory issue (default)
+  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
   // why try fp16x8 within 1 threads? ref: https://zhuanlan.zhihu.com/p/641639133
-  // 0. first, tid_0 load 128 bits(32 byte) in 1 memory issue and cache data into L2 cache.
+  // 0. first, tid_0 load 32 bytes in 1 memory issue and cache data into L2 cache.
   // 1. then, tid_1,...,tid_3 hit L2 cache and load data from L2 cache directly.
   half2 reg_x_0 = HALF2(x[idx + 0]);
   half2 reg_x_1 = HALF2(x[idx + 2]);
6 changes: 3 additions & 3 deletions relu/relu.cu
@@ -57,10 +57,10 @@ __global__ void relu_f16x2_kernel(half* x, half* y, int N) {
 __global__ void relu_f16x8_kernel(half* x, half* y, int N) {
   int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
   // manual unroll and improve L2 cache hit rate.
-  // Only L2 cache: load 32 bytes(128 bits) in 1 memory issue (default)
-  // Enable L1 cache: load 128 bytes(512 bits) in 1 memory issue (-Xptxas -dlcm=ca)
+  // Only L2 cache: load 32 bytes in 1 memory issue (default)
+  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
   // why try fp16x8 within 1 threads? ref: https://zhuanlan.zhihu.com/p/641639133
-  // 0. first, tid_0 load 128 bits(32 byte) in 1 memory issue and cache data into L2 cache.
+  // 0. first, tid_0 load 32 bytes in 1 memory issue and cache data into L2 cache.
   // 1. then, tid_1,...,tid_3 hit L2 cache and load data from L2 cache directly.
   half2 reg_x_0 = HALF2(x[idx + 0]);
   half2 reg_x_1 = HALF2(x[idx + 2]);
22 changes: 14 additions & 8 deletions rms-norm/README.md
@@ -9,6 +9,9 @@
 - [X] rms_norm_f16_f16_kernel
 - [X] rms_norm_f16x2_f16_kernel
 - [X] rms_norm_f16x8_f16_kernel
+- [X] rms_norm_f16x8_f32_kernel
+- [X] rms_norm_f16x16_f16_kernel
+- [X] rms_norm_f16x16_f32_kernel
 - [X] rms_norm_f16_f32_kernel
 - [X] PyTorch bindings

@@ -24,14 +27,17 @@ python3 rms_norm.py

 ```bash
 --------------------------------------------------------------------------------
-out_f32: [0.66361254, 0.69628561, 0.51101440], time:0.01188707ms
-out_f32x4: [0.66361260, 0.69628561, 0.51101440], time:0.00833464ms
-out_f32_th: [0.66361588, 0.69628906, 0.51101691], time:0.04334593ms
+out_f32: [0.92419142, -0.08846965, 1.06359947], time:0.03389192ms
+out_f32x4: [0.92419147, -0.08846966, 1.06359959], time:0.00855207ms
+out_f32_th: [0.92419606, -0.08847010, 1.06360483], time:0.04171062ms
 --------------------------------------------------------------------------------
-out_f16f16: [0.66357422, 0.69580078, 0.51074219], time:0.01201081ms
-out_f16x2f16: [0.66357422, 0.69580078, 0.51074219], time:0.00668955ms
-out_f16x8f16: [0.66650391, 0.69921875, 0.51318359], time:0.00398421ms
-out_f16f32: [0.66357422, 0.69628906, 0.51123047], time:0.01176858ms
-out_f16_th: [0.66357422, 0.69580078, 0.51074219], time:0.04448509ms
+out_f16f16: [0.92431641, -0.08843994, 1.06347656], time:0.03518176ms
+out_f16x2f16: [0.92431641, -0.08843994, 1.06347656], time:0.01200986ms
+out_f16x8f16: [0.92431641, -0.08843994, 1.06347656], time:0.00625682ms
+out_f16x8f32: [0.92431641, -0.08843994, 1.06347656], time:0.00625014ms
+out_f16x16f16: [0.92431641, -0.08843994, 1.06347656], time:0.02620339ms
+out_f16x16f32: [0.92431641, -0.08843994, 1.06347656], time:0.01505637ms
+out_f16f32: [0.92431641, -0.08843994, 1.06347656], time:0.03300810ms
+out_f16_th: [0.92431641, -0.08843994, 1.06347656], time:0.04187107ms
 --------------------------------------------------------------------------------
 ```
