diff --git a/README.md b/README.md
index 52496262..a1fd781d 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,11 @@
 | ✔️ [embedding_f16x2](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f16x8](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f16x8_pack](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️⭐️|
+| ✔️ [mat_trans_f32_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
+| ✔️ [mat_trans_f32_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
+| ✔️ [mat_trans_f32_diagonal2d](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
+| ✔️ [mat_trans_f32x4_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
+| ✔️ [mat_trans_f32x4_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
 | ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
 | ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
@@ -124,7 +129,7 @@
 | ✔️ [sgemm_t_8x8_sliced_k...dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k16...async](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
-| ✔️ [sgemm_wmma_m16n16k8...stage2/3*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
+| ✔️ [sgemm_wmma_m16n16k8...stages*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_wmma_m16n16k8...swizzle*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_naive_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️|
 | ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
@@ -138,10 +143,9 @@
 | ✔️ [hgemm_wmma_m16n16k16...naive*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...mma4x2*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...warp2x4*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_wmma_m16n16k16...warp2x4x2*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...async*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_wmma_m16n16k16...stage3/4*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_wmma_m16n16k16...stages*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...swizzle*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m32n8k16...dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
diff --git a/mat_transpose/.gitignore b/mat_transpose/.gitignore
new file mode 100644
index 00000000..eb33da95
--- /dev/null
+++ b/mat_transpose/.gitignore
@@ -0,0 +1,10 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+
diff --git a/mat_transpose/README.md b/mat_transpose/README.md
new file mode 100755
index 00000000..4417c305
--- /dev/null
+++ b/mat_transpose/README.md
@@ -0,0 +1,151 @@
+# Mat Transpose
+
+## 0x00 Overview
+
+Contents:
+
+- [X] mat_transpose_f32_col2row_kernel
+- [X] mat_transpose_f32_row2col_kernel
+- [X] mat_transpose_f32x4_col2row_kernel (float4 vectorized)
+- [X] mat_transpose_f32x4_row2col_kernel (float4 vectorized)
+- [X] mat_transpose_f32_diagonal2d_kernel (diagonal block reordering, for S == K)
+- [ ] mat_transpose_f32x4_shared_col2row_kernel (float4 vectorized, shared memory) (WIP)
+- [ ] mat_transpose_f32x4_shared_row2col_kernel (float4 vectorized, shared memory) (WIP)
+- [ ] mat_transpose_f32x4_shared_bcf_col2row_kernel (float4 vectorized, shared memory, bank-conflict free) (WIP)
+- [ ] mat_transpose_f32x4_shared_bcf_row2col_kernel (float4 vectorized, shared memory, bank-conflict free) (WIP)
+- [X] PyTorch bindings
+
+
+
+## Testing
+
+```bash
+# Build for the Ada architecture only. If unset, all architectures (Volta, Ampere, Ada, Hopper, ...) are compiled, which takes much longer.
+export TORCH_CUDA_ARCH_LIST=Ada
+python3 mat_transpose.py
+```
+
+Output:
+
+```bash
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=1024, K=1024
+                  out_original: [1.22921503, 1.82269871, -0.72512561], validate False, time:0.00008798ms
+               out_f32_col2row: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.03252983ms
+               out_f32_row2col: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.02068520ms
+           out_f32_col2row(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.02265215ms
+           out_f32_row2col(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01682043ms
+             out_f32_diagnonal: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01259637ms
+             out_f32x4_col2row: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.03311539ms
+             out_f32x4_row2col: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01966453ms
+         out_f32x4_col2row(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01993465ms
+         out_f32x4_row2col(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01886630ms
+                    out_f32_th: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.04084969ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=1024, K=2048
+                  out_original: [1.68499732, 0.07425918, -0.02102743], validate False, time:0.00008655ms
+               out_f32_col2row: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05558133ms
+               out_f32_row2col: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03320456ms
+           out_f32_col2row(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02773643ms
+           out_f32_row2col(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02775192ms
+             out_f32x4_col2row: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05540919ms
+             out_f32x4_row2col: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03241920ms
+         out_f32x4_col2row(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03086519ms
+         out_f32x4_row2col(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02918243ms
+                    out_f32_th: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05527997ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=1024, K=4096
+                  out_original: [-1.25576293, -1.05169642, 0.3411217], validate False, time:0.00008583ms
+               out_f32_col2row: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.10143566ms
+               out_f32_row2col: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05657411ms
+           out_f32_col2row(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04857659ms
+           out_f32_row2col(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04864573ms
+             out_f32x4_col2row: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.10081601ms
+             out_f32x4_row2col: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05694509ms
+         out_f32x4_col2row(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05282903ms
+         out_f32x4_row2col(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04989004ms
+                    out_f32_th: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.08283234ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=2048, K=1024
+                  out_original: [-0.47698042, -0.33631387, -0.16439888], validate False, time:0.00008464ms
+               out_f32_col2row: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05773354ms
+               out_f32_row2col: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03202701ms
+           out_f32_col2row(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02529335ms
+           out_f32_row2col(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02532363ms
+             out_f32x4_col2row: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05734038ms
+             out_f32x4_row2col: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03257370ms
+         out_f32x4_col2row(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03162861ms
+         out_f32x4_row2col(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02920556ms
+                    out_f32_th: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05421734ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=2048, K=2048
+                  out_original: [-1.11287403, -0.41300669, 0.3849003], validate False, time:0.00008488ms
+               out_f32_col2row: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.10564256ms
+               out_f32_row2col: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05567479ms
+           out_f32_col2row(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.04766870ms
+           out_f32_row2col(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.04748774ms
+             out_f32_diagnonal: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.02389789ms
+             out_f32x4_col2row: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.10338593ms
+             out_f32x4_row2col: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05683303ms
+         out_f32x4_col2row(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05457044ms
+         out_f32x4_row2col(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05046129ms
+                    out_f32_th: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.08376551ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=2048, K=4096
+                  out_original: [1.41623259, -0.94387418, 0.48682433], validate False, time:0.00008965ms
+               out_f32_col2row: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.19712996ms
+               out_f32_row2col: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.10346484ms
+           out_f32_col2row(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.08918452ms
+           out_f32_row2col(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.08975387ms
+             out_f32x4_col2row: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.19636393ms
+             out_f32x4_row2col: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.10541511ms
+         out_f32x4_col2row(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.09951663ms
+         out_f32x4_row2col(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.09154367ms
+                    out_f32_th: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.14955282ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=4096, K=1024
+                  out_original: [-0.58965021, 0.14326878, -0.19429833], validate False, time:0.00008726ms
+               out_f32_col2row: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.10833144ms
+               out_f32_row2col: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05539703ms
+           out_f32_col2row(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.04996872ms
+           out_f32_row2col(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.04996324ms
+             out_f32x4_col2row: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.10815549ms
+             out_f32x4_row2col: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05626845ms
+         out_f32x4_col2row(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05652213ms
+         out_f32x4_row2col(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05046964ms
+                    out_f32_th: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.08028626ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=4096, K=2048
+                  out_original: [-0.86244643, 0.61793995, -0.78971046], validate False, time:0.00008225ms
+               out_f32_col2row: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.20896244ms
+               out_f32_row2col: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10261559ms
+           out_f32_col2row(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09091687ms
+           out_f32_row2col(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09096813ms
+             out_f32x4_col2row: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.20603800ms
+             out_f32x4_row2col: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10330606ms
+         out_f32x4_col2row(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10366035ms
+         out_f32x4_row2col(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09077668ms
+                    out_f32_th: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.14721990ms
+------------------------------------------------------------------------------------------------------------------------
+------------------------------------------------------------------------------------------------------------------------
+                                                  S=4096, K=4096
+                  out_original: [-1.41012037, 0.45044342, 0.36045134], validate False, time:0.00008726ms
+               out_f32_col2row: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.38568211ms
+               out_f32_row2col: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.41187572ms
+           out_f32_col2row(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.21557069ms
+           out_f32_row2col(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.21556497ms
+             out_f32_diagnonal: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.30571437ms
+             out_f32x4_col2row: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.38697243ms
+             out_f32x4_row2col: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.30080318ms
+         out_f32x4_col2row(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.23044729ms
+         out_f32x4_row2col(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.34491825ms
+                    out_f32_th: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.56499386ms
+------------------------------------------------------------------------------------------------------------------------
+```
diff --git a/mat_transpose/mat_transpose.cu b/mat_transpose/mat_transpose.cu
new file mode 100644
index 00000000..b700325e
--- /dev/null
+++ b/mat_transpose/mat_transpose.cu
@@ -0,0 +1,220 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <float.h>
+#include <vector>
+#include <algorithm>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <torch/types.h>
+#include <torch/extension.h>
+
+#define WARP_SIZE 256 // used here as the 1-D launch block size, not the hardware warp size
+#define WARP_SIZE_S 16
+#define INT4(value) (reinterpret_cast<int4 *>(&(value))[0])
+#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2 *>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
+#define MAX_EXP_F32 88.3762626647949f
+#define MIN_EXP_F32 -88.3762626647949f
+#define MAX_EXP_F16 __float2half(11.089866488461016f)
+#define MIN_EXP_F16 __float2half(-9.704060527839234f)
+
+// -------------------------------------- FP32 --------------------------------------
+// col2row means read x[row][col] and write y[col][row]
+// row2col means read x[col][row] and write y[row][col]
+__global__ void mat_transpose_f32_col2row_kernel(
+    float *x, float *y, const int row, const int col) {
+  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int global_row = global_idx / col;
+  const int global_col = global_idx % col;
+  if (global_idx < row * col) {
+    y[global_col * row + global_row] = x[global_idx];
+  }
+}
+
+__global__ void mat_transpose_f32_row2col_kernel(
+    float *x, float *y, const int row, const int col) {
+  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int global_col = global_idx / row;
+  const int global_row = global_idx % row;
+  if (global_idx < row * col) {
+    y[global_idx] = x[global_row * col + global_col];
+  }
+}
+
+__global__ void mat_transpose_f32x4_col2row_kernel(
+    float *x, float *y, const int row, const int col) {
+  int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int global_col = (global_idx * 4) % col;
+  int global_row = (global_idx * 4) / col;
+
+  if (global_row < row && global_col + 3 < col) {
+    float4 x_val = reinterpret_cast<float4 *>(x)[global_idx];
+
+    y[global_col * row + global_row] = x_val.x;
+    y[(global_col + 1) * row + global_row] = x_val.y;
+    y[(global_col + 2) * row + global_row] = x_val.z;
+    y[(global_col + 3) * row + global_row] = x_val.w;
+  }
+}
+__global__ void mat_transpose_f32x4_row2col_kernel(
+    float *x, float *y, const int row, const int col) {
+  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int global_col = (global_idx * 4) / row;
+  const int global_row = (global_idx * 4) % row;
+
+  if (global_row < row && global_col < col) {
+    float4 x_val;
+    x_val.x = x[global_row * col + global_col];
+    x_val.y = x[(global_row + 1) * col + global_col];
+    x_val.z = x[(global_row + 2) * col + global_col];
+    x_val.w = x[(global_row + 3) * col + global_col];
+    reinterpret_cast<float4 *>(y)[global_idx] = FLOAT4(x_val);
+  }
+}
+
+// works only when row == col
+__global__ void mat_transpose_f32_diagonal2d_kernel(
+    float *x, float *y, int row, int col) {
+  const int block_y = blockIdx.x;
+  const int block_x = (blockIdx.x + blockIdx.y) % gridDim.x;
+  const int global_col = threadIdx.x + blockDim.x * block_x;
+  const int global_row = threadIdx.y + blockDim.y * block_y;
+  if (global_col < col && global_row < row) {
+    y[global_row * col + global_col] = x[global_col * row + global_row];
+  }
+}
+
+__global__ void mat_transpose_f32_col2row2d_kernel(
+    float *x, float *y, const int row, const int col) {
+  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
+  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
+  if (global_x < col && global_y < row) {
+    y[global_x * row + global_y] = x[global_y * col + global_x];
+  }
+}
+
+__global__ void mat_transpose_f32_row2col2d_kernel(
+    float *x, float *y, const int row, const int col) {
+  const int global_y = blockIdx.x * blockDim.x + threadIdx.x;
+  const int global_x = blockIdx.y * blockDim.y + threadIdx.y;
+  if (global_y < col && global_x < row) {
+    y[global_y * row + global_x] = x[global_x * col + global_y];
+  }
+}
+
+__global__ void mat_transpose_f32x4_col2row2d_kernel(
+    float *x, float *y, const int row, const int col) {
+  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
+  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
+  if (global_x * 4 + 3 < col && global_y < row) {
+    float4 x_val = reinterpret_cast<float4 *>(x)[global_y * col / 4 + global_x];
+    y[(global_x * 4) * row + global_y] = x_val.x;
+    y[(global_x * 4 + 1) * row + global_y] = x_val.y;
+    y[(global_x * 4 + 2) * row + global_y] = x_val.z;
+    y[(global_x * 4 + 3) * row + global_y] = x_val.w;
+  }
+}
+__global__ void mat_transpose_f32x4_row2col2d_kernel(
+    float *x, float *y, const int row, const int col) {
+  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
+  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
+  if (global_y * 4 + 3 < row && global_x < col) {
+    float4 x_val;
+    x_val.x = x[(global_y * 4) * col + global_x];
+    x_val.y = x[(global_y * 4 + 1) * col + global_x];
+    x_val.z = x[(global_y * 4 + 2) * col + global_x];
+    x_val.w = x[(global_y * 4 + 3) * col + global_x];
+    reinterpret_cast<float4 *>(y)[global_x * row / 4 + global_y] = FLOAT4(x_val);
+  }
+}
+
+// TODO: may support shared memory optimize ?
+__global__ void mat_transpose_f32x4_shared_col2row2d_kernel(
+    float *x, float *y, const int row, const int col) {
+  return;
+}
+__global__ void mat_transpose_f32x4_shared_row2col2d_kernel(
+    float *x, float *y, const int row, const int col) {
+  return;
+}
+__global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel(
+    float *x, float *y, const int row, const int col) {
+  return;
+}
+
+// TODO: may support fp16 mat transpose ?
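+
+// Illustrative sketch only (not one of the kernels benchmarked above): a plain
+// tiled shared-memory transpose, as one possible starting point for the WIP
+// *_shared_* kernels listed in the README. The tile size, the kernel name and
+// the +1 padding (to avoid shared-memory bank conflicts) are assumptions, not
+// the final implementation.
+__global__ void mat_transpose_f32_shared_tile_sketch_kernel(
+    float *x, float *y, const int row, const int col) {
+  __shared__ float tile[WARP_SIZE_S][WARP_SIZE_S + 1];
+  const int in_col = blockIdx.x * WARP_SIZE_S + threadIdx.x;
+  const int in_row = blockIdx.y * WARP_SIZE_S + threadIdx.y;
+  if (in_row < row && in_col < col) {
+    tile[threadIdx.y][threadIdx.x] = x[in_row * col + in_col]; // coalesced load
+  }
+  __syncthreads();
+  // swap the block indices so the store to y is coalesced as well
+  const int out_col = blockIdx.y * WARP_SIZE_S + threadIdx.x; // column in y, [0, row)
+  const int out_row = blockIdx.x * WARP_SIZE_S + threadIdx.y; // row in y, [0, col)
+  if (out_row < col && out_col < row) {
+    y[out_row * row + out_col] = tile[threadIdx.x][threadIdx.y];
+  }
+}
+// A launch would use dim3 block(WARP_SIZE_S, WARP_SIZE_S) and
+// dim3 grid((col + WARP_SIZE_S - 1) / WARP_SIZE_S, (row + WARP_SIZE_S - 1) / WARP_SIZE_S),
+// i.e. the same shape as the existing 2d bindings with n_element_row = n_element_col = 1.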
+
+// --------------------- PyTorch bindings for custom kernel -----------------------
+#define STRINGFY(str) #str
+#define TORCH_BINDING_COMMON_EXTENSION(func) \
+  m.def(STRINGFY(func), &func, STRINGFY(func));
+
+#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                   \
+  if (((T).options().dtype() != (th_type)))                    \
+  {                                                            \
+    std::cout << "Tensor Info:" << (T).options() << std::endl; \
+    throw std::runtime_error("values must be " #th_type);      \
+  }
+
+#define TORCH_BINDING_MAT_TRANSPOSE(tag, th_type, element_type, n_pack)  \
+  void mat_transpose_##tag(torch::Tensor x, torch::Tensor y)             \
+  {                                                                      \
+    CHECK_TORCH_TENSOR_DTYPE(x, (th_type))                               \
+    CHECK_TORCH_TENSOR_DTYPE(y, (th_type))                               \
+    const int M = x.size(0);                                             \
+    const int N = x.size(1);                                             \
+    dim3 block(WARP_SIZE);                                               \
+    dim3 grid(((N * M + WARP_SIZE - 1) / n_pack / WARP_SIZE));           \
+    mat_transpose_##tag##_kernel<<<grid, block>>>(                       \
+        reinterpret_cast<element_type *>(x.data_ptr()),                  \
+        reinterpret_cast<element_type *>(y.data_ptr()), M, N);           \
+  }
+
+#define TORCH_BINDING_MAT_TRANSPOSE2D(tag, th_type, element_type, n_element_row, n_element_col) \
+  void mat_transpose_##tag##2d(torch::Tensor x, torch::Tensor y)         \
+  {                                                                      \
+    CHECK_TORCH_TENSOR_DTYPE(x, (th_type))                               \
+    CHECK_TORCH_TENSOR_DTYPE(y, (th_type))                               \
+    const int M = x.size(0);                                             \
+    const int N = x.size(1);                                             \
+    dim3 block(WARP_SIZE_S, WARP_SIZE_S);                                \
+    dim3 grid((N + WARP_SIZE_S - 1) / (WARP_SIZE_S * n_element_col),     \
+              (M + WARP_SIZE_S - 1) / (WARP_SIZE_S / n_element_row));    \
+    mat_transpose_##tag##2d_kernel<<<grid, block>>>(                     \
+        reinterpret_cast<element_type *>(x.data_ptr()),                  \
+        reinterpret_cast<element_type *>(y.data_ptr()), M, N);           \
+  }
+
+// 1d index
+TORCH_BINDING_MAT_TRANSPOSE(f32_col2row, torch::kFloat32, float, 1)
+TORCH_BINDING_MAT_TRANSPOSE(f32_row2col, torch::kFloat32, float, 1)
+TORCH_BINDING_MAT_TRANSPOSE(f32x4_col2row, torch::kFloat32, float, 4)
+TORCH_BINDING_MAT_TRANSPOSE(f32x4_row2col, torch::kFloat32, float, 4)
+// 2d index. easier for diagonal
+TORCH_BINDING_MAT_TRANSPOSE2D(f32_col2row, torch::kFloat32, float, 1, 1)
+TORCH_BINDING_MAT_TRANSPOSE2D(f32_row2col, torch::kFloat32, float, 1, 1)
+TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_col2row, torch::kFloat32, float, 1, 4)
+TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_row2col, torch::kFloat32, float, 4, 1)
+// diagonal index method.
+TORCH_BINDING_MAT_TRANSPOSE2D(f32_diagonal, torch::kFloat32, float, 1, 1)
+// TODO: may support shared memory optimize ?
+// TODO: may support fp16 mat transpose ?
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // 1d index
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_col2row)
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_col2row)
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_row2col)
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_row2col)
+  // 2d index. easier for diagonal
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_col2row2d)
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_col2row2d)
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_row2col2d)
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_row2col2d)
+  // diagonal index method.
+  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_diagonal2d)
+}
diff --git a/mat_transpose/mat_transpose.py b/mat_transpose/mat_transpose.py
new file mode 100644
index 00000000..4a6026cc
--- /dev/null
+++ b/mat_transpose/mat_transpose.py
@@ -0,0 +1,91 @@
+import torch
+import time
+from torch.utils.cpp_extension import load
+from typing import Optional
+from functools import partial
+
+torch.set_grad_enabled(False)
+
+# Load the CUDA kernel as a python module
+lib = load(
+    name="mat_transpose_lib",
+    sources=["mat_transpose.cu"],
+    extra_cuda_cflags=[
+        "-O3",
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF_CONVERSIONS__",
+        "-U__CUDA_NO_HALF2_OPERATORS__",
+        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+        "--expt-relaxed-constexpr",
+        "--expt-extended-lambda",
+        "--use_fast_math",
+    ],
+    extra_cflags=["-std=c++17"],
+)
+
+
+def run_benchmark(
+    perf_func: callable,
+    x: torch.Tensor,
+    tag: str,
+    out: Optional[torch.Tensor] = None,
+    warmup: int = 10,
+    iters: int = 1000,
+    show_all: bool = False,
+):
+    if out is not None:
+        out.fill_(0)
+    # warmup
+    if out is not None:
+        for i in range(warmup):
+            perf_func(x, out)
+    else:
+        for i in range(warmup):
+            _ = perf_func(x)
+    torch.cuda.synchronize()
+
+    start = time.time()
+    # iters
+    if out is not None:
+        for i in range(iters):
+            perf_func(x, out)
+    else:
+        for i in range(iters):
+            out = perf_func(x)
+    torch.cuda.synchronize()
+    end = time.time()
+    total_time = (end - start) * 1000  # ms
+    mean_time = total_time / iters
+    out_info = f"out_{tag}"
+    real_t = f"{out.T.equal(x)}"
+    out_val = out[:2, :2].flatten().detach().cpu().numpy().tolist()[:3]
+    out_val = [round(v, 8) for v in out_val]
+    print(f"{out_info:>30}: {out_val}, validate {real_t:<5}, time:{mean_time:.8f}ms")
+    if show_all:
+        print(out)
+    return out, mean_time
+
+
+Ss = [1024, 2048, 4096]
+Ks = [1024, 2048, 4096]
+SKs = [(S, K) for S in Ss for K in Ks]
+copy_x = lambda x: x
+# show the three elements x[0][0], x[0][1], x[1][0]
+for S, K in SKs:
+    print("-" * 120)
+    print(" " * 50 + f"S={S}, K={K}")
+    x = torch.randn((S, K)).cuda().float().contiguous()
+    y = torch.randn((K, S)).cuda().float().contiguous()
+    run_benchmark(partial(copy_x), x, "original")
+    run_benchmark(lib.mat_transpose_f32_col2row, x, "f32_col2row", y)
+    run_benchmark(lib.mat_transpose_f32_row2col, x, "f32_row2col", y)
+    run_benchmark(lib.mat_transpose_f32_col2row2d, x, "f32_col2row(2d)", y)
+    run_benchmark(lib.mat_transpose_f32_row2col2d, x, "f32_row2col(2d)", y)
+    if S == K:
+        run_benchmark(lib.mat_transpose_f32_diagonal2d, x, "f32_diagnonal", y)
+    run_benchmark(lib.mat_transpose_f32x4_col2row, x, "f32x4_col2row", y)
+    run_benchmark(lib.mat_transpose_f32x4_row2col, x, "f32x4_row2col", y)
+    run_benchmark(lib.mat_transpose_f32x4_col2row2d, x, "f32x4_col2row(2d)", y)
+    run_benchmark(lib.mat_transpose_f32x4_row2col2d, x, "f32x4_row2col(2d)", y)
+    run_benchmark(partial(torch.transpose_copy, dim0=0, dim1=1, out=y), x, "f32_th")
+    print("-" * 120)
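+
+# Optional sanity check (illustrative sketch, not part of the benchmark above):
+# compare one kernel against PyTorch's reference transpose on a small matrix.
+# The shapes and the chosen kernel here are arbitrary examples.
+x_small = torch.randn((128, 256)).cuda().float().contiguous()
+y_small = torch.zeros((256, 128)).cuda().float().contiguous()
+lib.mat_transpose_f32x4_row2col2d(x_small, y_small)
+assert torch.equal(y_small, x_small.t()), "transpose kernel output mismatch"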