[Mat][Trans] Add f32/f32x4 row/col first kernel #89

Merged 6 commits on Oct 18, 2024
10 changes: 7 additions & 3 deletions README.md
@@ -61,6 +61,11 @@
| ✔️ [embedding_f16x2](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
| ✔️ [embedding_f16x8](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
| ✔️ [embedding_f16x8_pack](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️⭐️|
| ✔️ [mat_trans_f32_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
| ✔️ [mat_trans_f32_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
| ✔️ [mat_trans_f32_diagonal2d](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
| ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
@@ -124,7 +129,7 @@
| ✔️ [sgemm_t_8x8_sliced_k...dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k16...async](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...stage2/3*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...stages*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...swizzle*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_naive_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️|
| ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
@@ -138,10 +143,9 @@
| ✔️ [hgemm_wmma_m16n16k16...naive*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...mma4x2*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...warp2x4*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...warp2x4x2*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...async*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...stage3/4*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...stages*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...swizzle*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m32n8k16...dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
10 changes: 10 additions & 0 deletions mat_transpose/.gitignore
@@ -0,0 +1,10 @@
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp

151 changes: 151 additions & 0 deletions mat_transpose/README.md
@@ -0,0 +1,151 @@
# Mat Transpose

## 0x00 Overview

Includes the following (a minimal CUDA sketch of the basic col2row/row2col kernels follows this list):

- [X] mat_transpose_f32_col2row_kernel
- [X] mat_transpose_f32_row2col_kernel
- [X] mat_transpose_f32x4_col2row_kernel (float4 vectorized version)
- [X] mat_transpose_f32x4_row2col_kernel (float4 vectorized version)
- [X] mat_transpose_f32_diagnonal (diagonal coordinate mapping, for the S=K case)
- [ ] mat_transpose_f32x4_shared_col2row_kernel (float4 vectorized version, shared memory) (WIP)
- [ ] mat_transpose_f32x4_shared_row2col_kernel (float4 vectorized version, shared memory) (WIP)
- [ ] mat_transpose_f32x4_shared_bcf_col2row_kernel (float4 vectorized version, shared memory, bank-conflict free) (WIP)
- [ ] mat_transpose_f32x4_shared_bcf_row2col_kernel (float4 vectorized version, shared memory, bank-conflict free) (WIP)
- [X] PyTorch bindings
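
The two basic variants differ only in which side of the copy is coalesced. Below is a minimal sketch of that idea, assuming a row-major S×K input transposed into a row-major K×S output; the kernel names, launch shape, and the exact row2col/col2row naming convention are illustrative assumptions, not copied from mat_transpose.cu.

```cuda
// Minimal sketch (not the PR's exact code): naive f32 transpose kernels for a
// row-major S x K input `x`, producing a row-major K x S output `y`.
#include <cuda_runtime.h>

// "row2col" reading: consecutive threads walk along a row of x, so loads are
// coalesced while stores into y stride down a column.
__global__ void mat_transpose_f32_row2col_2d_sketch(const float *x, float *y,
                                                    int S, int K) {
  int col = blockIdx.x * blockDim.x + threadIdx.x; // 0..K-1, fastest-varying
  int row = blockIdx.y * blockDim.y + threadIdx.y; // 0..S-1
  if (row < S && col < K) {
    // Coalesced read from x, strided write into y.
    y[col * S + row] = x[row * K + col];
  }
}

// "col2row" reading: consecutive threads walk down a column of x, so loads
// are strided while stores into y are coalesced.
__global__ void mat_transpose_f32_col2row_2d_sketch(const float *x, float *y,
                                                    int S, int K) {
  int row = blockIdx.x * blockDim.x + threadIdx.x; // fastest-varying
  int col = blockIdx.y * blockDim.y + threadIdx.y;
  if (row < S && col < K) {
    // Strided read from x, coalesced write into y.
    y[col * S + row] = x[row * K + col];
  }
}

// Example launch (hypothetical shapes):
//   dim3 block(32, 8);
//   dim3 grid((K + block.x - 1) / block.x, (S + block.y - 1) / block.y);
//   mat_transpose_f32_row2col_2d_sketch<<<grid, block>>>(d_x, d_y, S, K);
```

With this mapping, row2col keeps the loads coalesced at the cost of strided stores and col2row does the opposite; the shared-memory (and bank-conflict-free) variants on the to-do list aim to make both sides coalesced by staging a tile in shared memory.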



## Testing

```bash
# Build only for the Ada architecture. Without this, all architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 mat_transpose.py
```

Output:

```bash
------------------------------------------------------------------------------------------------------------------------
S=1024, K=1024
out_original: [1.22921503, 1.82269871, -0.72512561], validate False, time:0.00008798ms
out_f32_col2row: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.03252983ms
out_f32_row2col: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.02068520ms
out_f32_col2row(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.02265215ms
out_f32_row2col(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01682043ms
out_f32_diagnonal: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01259637ms
out_f32x4_col2row: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.03311539ms
out_f32x4_row2col: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01966453ms
out_f32x4_col2row(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01993465ms
out_f32x4_row2col(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01886630ms
out_f32_th: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.04084969ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=1024, K=2048
out_original: [1.68499732, 0.07425918, -0.02102743], validate False, time:0.00008655ms
out_f32_col2row: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05558133ms
out_f32_row2col: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03320456ms
out_f32_col2row(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02773643ms
out_f32_row2col(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02775192ms
out_f32x4_col2row: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05540919ms
out_f32x4_row2col: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03241920ms
out_f32x4_col2row(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03086519ms
out_f32x4_row2col(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02918243ms
out_f32_th: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05527997ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=1024, K=4096
out_original: [-1.25576293, -1.05169642, 0.3411217], validate False, time:0.00008583ms
out_f32_col2row: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.10143566ms
out_f32_row2col: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05657411ms
out_f32_col2row(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04857659ms
out_f32_row2col(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04864573ms
out_f32x4_col2row: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.10081601ms
out_f32x4_row2col: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05694509ms
out_f32x4_col2row(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05282903ms
out_f32x4_row2col(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04989004ms
out_f32_th: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.08283234ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=2048, K=1024
out_original: [-0.47698042, -0.33631387, -0.16439888], validate False, time:0.00008464ms
out_f32_col2row: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05773354ms
out_f32_row2col: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03202701ms
out_f32_col2row(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02529335ms
out_f32_row2col(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02532363ms
out_f32x4_col2row: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05734038ms
out_f32x4_row2col: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03257370ms
out_f32x4_col2row(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03162861ms
out_f32x4_row2col(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02920556ms
out_f32_th: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05421734ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=2048, K=2048
out_original: [-1.11287403, -0.41300669, 0.3849003], validate False, time:0.00008488ms
out_f32_col2row: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.10564256ms
out_f32_row2col: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05567479ms
out_f32_col2row(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.04766870ms
out_f32_row2col(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.04748774ms
out_f32_diagnonal: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.02389789ms
out_f32x4_col2row: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.10338593ms
out_f32x4_row2col: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05683303ms
out_f32x4_col2row(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05457044ms
out_f32x4_row2col(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05046129ms
out_f32_th: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.08376551ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=2048, K=4096
out_original: [1.41623259, -0.94387418, 0.48682433], validate False, time:0.00008965ms
out_f32_col2row: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.19712996ms
out_f32_row2col: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.10346484ms
out_f32_col2row(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.08918452ms
out_f32_row2col(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.08975387ms
out_f32x4_col2row: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.19636393ms
out_f32x4_row2col: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.10541511ms
out_f32x4_col2row(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.09951663ms
out_f32x4_row2col(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.09154367ms
out_f32_th: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.14955282ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=4096, K=1024
out_original: [-0.58965021, 0.14326878, -0.19429833], validate False, time:0.00008726ms
out_f32_col2row: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.10833144ms
out_f32_row2col: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05539703ms
out_f32_col2row(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.04996872ms
out_f32_row2col(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.04996324ms
out_f32x4_col2row: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.10815549ms
out_f32x4_row2col: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05626845ms
out_f32x4_col2row(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05652213ms
out_f32x4_row2col(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05046964ms
out_f32_th: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.08028626ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=4096, K=2048
out_original: [-0.86244643, 0.61793995, -0.78971046], validate False, time:0.00008225ms
out_f32_col2row: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.20896244ms
out_f32_row2col: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10261559ms
out_f32_col2row(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09091687ms
out_f32_row2col(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09096813ms
out_f32x4_col2row: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.20603800ms
out_f32x4_row2col: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10330606ms
out_f32x4_col2row(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10366035ms
out_f32x4_row2col(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09077668ms
out_f32_th: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.14721990ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
S=4096, K=4096
out_original: [-1.41012037, 0.45044342, 0.36045134], validate False, time:0.00008726ms
out_f32_col2row: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.38568211ms
out_f32_row2col: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.41187572ms
out_f32_col2row(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.21557069ms
out_f32_row2col(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.21556497ms
out_f32_diagnonal: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.30571437ms
out_f32x4_col2row: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.38697243ms
out_f32x4_row2col: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.30080318ms
out_f32x4_col2row(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.23044729ms
out_f32x4_row2col(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.34491825ms
out_f32_th: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.56499386ms
------------------------------------------------------------------------------------------------------------------------
```
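
The f32x4 rows in the benchmark above come from the vectorized kernels. A minimal sketch of that idea follows, assuming K is a multiple of 4 and 16-byte-aligned device pointers; the kernel name and indexing are illustrative assumptions, not taken from mat_transpose.cu.

```cuda
// Minimal sketch of the f32x4 idea: each thread loads 4 consecutive f32
// values of a row with one 128-bit float4 load, then scatters them into
// 4 different rows of the transposed output.
#include <cuda_runtime.h>

__global__ void mat_transpose_f32x4_row2col_sketch(const float *x, float *y,
                                                   int S, int K) {
  int col4 = blockIdx.x * blockDim.x + threadIdx.x; // index in units of 4 cols
  int row  = blockIdx.y * blockDim.y + threadIdx.y;
  int col  = col4 * 4;
  if (row < S && col + 3 < K) {
    // One vectorized, coalesced 128-bit read covering columns col..col+3.
    float4 v = reinterpret_cast<const float4 *>(x + row * K + col)[0];
    // Four scalar writes, each landing in a different output row.
    y[(col + 0) * S + row] = v.x;
    y[(col + 1) * S + row] = v.y;
    y[(col + 2) * S + row] = v.z;
    y[(col + 3) * S + row] = v.w;
  }
}
```

Vectorization cuts the number of load instructions by 4x, but the stores remain strided, which is consistent with the f32x4 timings above tracking the scalar 2D kernels fairly closely.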