[RELU][Half] support fp16x8 RELU kernel (#26)
* Update layer_norm.cu

* Update rms_norm.cu

* Update README.md

* Update elementwise.cu

* Update elementwise.py

* Update README.md

* Update README.md

* Update relu.cu

* Update relu.py

* Update README.md
DefTruth authored Sep 14, 2024
1 parent 1ec1b3c commit 5fe6ca6
Showing 4 changed files with 50 additions and 13 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -34,6 +34,7 @@
| ✔️ [relu_f32x4_kernel](./relu/relu.cu)|f32|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16_kernel](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x2_kernel](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x8_kernel](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [warp_reduce_f16/bf16/f32/f8/i8_kernel](./reduce/block_all_reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_reduce_f32_kernel](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f32_f32_kernel](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
27 changes: 15 additions & 12 deletions relu/README.md
@@ -8,6 +8,7 @@
- [X] relu_f32x4_kernel (float4 vectorized version)
- [X] relu_f16_kernel (fp16 version)
- [X] relu_f16x2_kernel (fp16x2 vectorized version)
- [X] relu_f16x8_kernel (fp16x8 vectorized version)
- [X] PyTorch bindings


@@ -23,20 +24,22 @@ python3 relu.py

```bash
--------------------------------------------------------------------------------
-out_f32: [0.4740032, 0.0], time:0.01110339ms
-out_f32x4: [0.4740032, 0.0], time:0.01101446ms
-out_f32_th: [0.4740032, 0.0], time:0.00764275ms
+out_f32: [0.0, 0.0], time:0.01072860ms
+out_f32x4: [0.0, 0.0], time:0.01059222ms
+out_f32_th: [0.0, 0.0], time:0.00772071ms
--------------------------------------------------------------------------------
-out_f16: [0.47412109, 0.0], time:0.01086593ms
-out_f16x2: [0.47412109, 0.0], time:0.01093817ms
-out_f16_th: [0.47412109, 0.0], time:0.00773191ms
+out_f16: [0.0, 0.0], time:0.01077199ms
+out_f16x2: [0.0, 0.0], time:0.01084924ms
+out_f16x8: [0.0, 0.0], time:0.01083326ms
+out_f16_th: [0.0, 0.0], time:0.00762105ms
--------------------------------------------------------------------------------
-out_f32(v2): [0.4740032, 0.0], time:0.00343442ms
-out_f32x4(v2): [0.4740032, 0.0], time:0.00472522ms
-out_f32_th: [0.4740032, 0.0], time:0.00761461ms
+out_f32(v2): [0.0, 0.0], time:0.00346351ms
+out_f32x4(v2): [0.0, 0.0], time:0.00342798ms
+out_f32_th: [0.0, 0.0], time:0.01125073ms
--------------------------------------------------------------------------------
-out_f16(v2): [0.47412109, 0.0], time:0.00342822ms
-out_f16x2(v2): [0.47412109, 0.0], time:0.00345659ms
-out_f16_th: [0.47412109, 0.0], time:0.00793290ms
+out_f16(v2): [0.0, 0.0], time:0.00343585ms
+out_f16x2(v2): [0.0, 0.0], time:0.00339842ms
+out_f16x8(v2): [0.0, 0.0], time:0.00347090ms
+out_f16_th: [0.0, 0.0], time:0.00776792ms
--------------------------------------------------------------------------------
```
31 changes: 31 additions & 0 deletions relu/relu.cu
@@ -54,6 +54,33 @@ __global__ void relu_f16x2_kernel(half* x, half* y, int N) {
}
}

__global__ void relu_f16x8_kernel(half* x, half* y, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
// manual unroll to improve the L2 cache hit rate.
// L2 cache only: load 32 bytes (256 bits) in 1 memory issue (default)
// L1 cache enabled: load 128 bytes (1024 bits) in 1 memory issue (-Xptxas -dlcm=ca)
// why try fp16x8 within one thread? ref: https://zhuanlan.zhihu.com/p/641639133
// 0. first, tid_0 loads 128 bits (16 bytes) in 1 memory issue and caches the data in L2.
// 1. then, tid_1,...,tid_3 hit L2 and load their data from L2 directly.
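// Note: the half2 loads below are not bounds-checked and each guarded store writes two
// halves, so this kernel assumes N is a multiple of 8 (true for the benchmark in relu.py).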
half2 reg_x_0 = HALF2(x[idx + 0]);
half2 reg_x_1 = HALF2(x[idx + 2]);
half2 reg_x_2 = HALF2(x[idx + 4]);
half2 reg_x_3 = HALF2(x[idx + 6]);
half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
reg_y_0.x = __hmax(__float2half(0.0f), reg_x_0.x);
reg_y_0.y = __hmax(__float2half(0.0f), reg_x_0.y);
reg_y_1.x = __hmax(__float2half(0.0f), reg_x_1.x);
reg_y_1.y = __hmax(__float2half(0.0f), reg_x_1.y);
reg_y_2.x = __hmax(__float2half(0.0f), reg_x_2.x);
reg_y_2.y = __hmax(__float2half(0.0f), reg_x_2.y);
reg_y_3.x = __hmax(__float2half(0.0f), reg_x_3.x);
reg_y_3.y = __hmax(__float2half(0.0f), reg_x_3.y);
if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -104,18 +131,22 @@ TORCH_BINDING_RELU(f32, torch::kFloat32, float, 1)
TORCH_BINDING_RELU(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_RELU(f16, torch::kHalf, half, 1)
TORCH_BINDING_RELU(f16x2, torch::kHalf, half, 2)
TORCH_BINDING_RELU(f16x8, torch::kHalf, half, 8)
TORCH_BINDING_RELU_V2(f32, torch::kFloat32, float, 1)
TORCH_BINDING_RELU_V2(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_RELU_V2(f16, torch::kHalf, half, 1)
TORCH_BINDING_RELU_V2(f16x2, torch::kHalf, half, 2)
TORCH_BINDING_RELU_V2(f16x8, torch::kHalf, half, 8)

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(relu_f32)
TORCH_BINDING_COMMON_EXTENSION(relu_f32x4)
TORCH_BINDING_COMMON_EXTENSION(relu_f16)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x8)
TORCH_BINDING_COMMON_EXTENSION(relu_f32_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f32x4_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x2_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_v2)
}
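
The TORCH_BINDING_RELU macro bodies themselves are collapsed in this diff. For orientation only, here is a minimal launch sketch (the helper name and 256-thread block size are assumptions, not taken from the commit) showing how the grid shrinks when each thread handles 8 half elements:

```cuda
// Hypothetical helper, for illustration only: launches relu_f16x8_kernel over N halves.
// Assumes N is a multiple of 8 and that d_x/d_y are device pointers.
void launch_relu_f16x8(half* d_x, half* d_y, int N) {
  const int threads = 256;                                   // assumed block size
  const int blocks = (N + threads * 8 - 1) / (threads * 8);  // each thread covers 8 elements
  relu_f16x8_kernel<<<blocks, threads>>>(d_x, d_y, N);
}
```

With the benchmark's 256*256*4 elements and this assumed block size, that is 128 blocks, versus 1024 for the one-element-per-thread kernels.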
4 changes: 3 additions & 1 deletion relu/relu.py
@@ -57,7 +57,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str,


print("-" * 80)
-N_ELEMENTS = 256*92*16
+N_ELEMENTS = 256*256*4
x = torch.randn((N_ELEMENTS)).cuda().float()
run_benchmark(lib.relu_f32, x, "f32")
run_benchmark(lib.relu_f32x4, x, "f32x4")
@@ -67,6 +67,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str,
x_f16 = x.half()
run_benchmark(lib.relu_f16, x_f16, "f16")
run_benchmark(lib.relu_f16x2, x_f16, "f16x2")
run_benchmark(lib.relu_f16x8, x_f16, "f16x8")
run_benchmark(torch.relu, x_f16 , "f16_th")

print("-" * 80)
@@ -81,5 +82,6 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str,
y_f16 = torch.zeros_like(x_f16).cuda().half()
run_benchmark(lib.relu_f16_v2, x_f16, "f16(v2)", y_f16)
run_benchmark(lib.relu_f16x2_v2, x_f16, "f16x2(v2)", y_f16)
run_benchmark(lib.relu_f16x8_v2, x_f16, "f16x8(v2)", y_f16)
run_benchmark(torch.relu, x_f16 , "f16_th")
print("-" * 80)
