diff --git a/artifacts/models/manual_attention/bs1/attention_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu b/artifacts/models/manual_attention/bs1/attention_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu index 70f34e5ed..3e643da5f 100644 --- a/artifacts/models/manual_attention/bs1/attention_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu +++ b/artifacts/models/manual_attention/bs1/attention_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu @@ -407,100 +407,28 @@ extern void BlockFusionKernel_float_int64_t_float_float_int64_t_float_float_floa // - name: BatchMatMul_45_0 type: float shape: Shape{1, 12, 1, 64} extern "C" __global__ void BatchMatMul_float_float_float_cuda_BatchMatMul_45(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute) { - __shared__ float A_shared[32]; - __shared__ float B_shared[2048]; - { - { - float compute_local[2]; - - - float A_shared_local[1]; - float B_shared_local[2]; - compute_local[(0)] = 0.000000e+00f; - compute_local[(1)] = 0.000000e+00f; - for (int k_outer = 0; k_outer < 2; ++k_outer) { - __syncthreads(); - A_shared[(((int)threadIdx.x))] = A[((((((int)blockIdx.x) * 64) + (k_outer * 32)) + ((int)threadIdx.x)))]; - B_shared[(((int)threadIdx.x))] = B[((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)))]; - B_shared[((((int)threadIdx.x) + 32))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 32))]; - B_shared[((((int)threadIdx.x) + 64))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 64))]; - B_shared[((((int)threadIdx.x) + 96))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 96))]; - B_shared[((((int)threadIdx.x) + 128))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 128))]; - B_shared[((((int)threadIdx.x) + 160))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 160))]; - B_shared[((((int)threadIdx.x) + 192))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 192))]; - B_shared[((((int)threadIdx.x) + 224))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 224))]; - B_shared[((((int)threadIdx.x) + 256))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 256))]; - B_shared[((((int)threadIdx.x) + 288))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 288))]; - B_shared[((((int)threadIdx.x) + 320))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 320))]; - B_shared[((((int)threadIdx.x) + 352))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 352))]; - B_shared[((((int)threadIdx.x) + 384))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 384))]; - B_shared[((((int)threadIdx.x) + 416))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 416))]; - B_shared[((((int)threadIdx.x) + 448))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 448))]; - B_shared[((((int)threadIdx.x) + 480))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 480))]; - B_shared[((((int)threadIdx.x) + 512))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 512))]; - B_shared[((((int)threadIdx.x) + 544))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 544))]; - B_shared[((((int)threadIdx.x) + 576))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 576))]; - B_shared[((((int)threadIdx.x) + 608))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 608))]; - B_shared[((((int)threadIdx.x) + 640))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 640))]; - B_shared[((((int)threadIdx.x) + 672))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 672))]; - B_shared[((((int)threadIdx.x) + 704))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 704))]; - B_shared[((((int)threadIdx.x) + 736))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 736))]; - B_shared[((((int)threadIdx.x) + 768))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 768))]; - B_shared[((((int)threadIdx.x) + 800))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 800))]; - B_shared[((((int)threadIdx.x) + 832))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 832))]; - B_shared[((((int)threadIdx.x) + 864))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 864))]; - B_shared[((((int)threadIdx.x) + 896))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 896))]; - B_shared[((((int)threadIdx.x) + 928))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 928))]; - B_shared[((((int)threadIdx.x) + 960))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 960))]; - B_shared[((((int)threadIdx.x) + 992))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 992))]; - B_shared[((((int)threadIdx.x) + 1024))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1024))]; - B_shared[((((int)threadIdx.x) + 1056))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1056))]; - B_shared[((((int)threadIdx.x) + 1088))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1088))]; - B_shared[((((int)threadIdx.x) + 1120))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1120))]; - B_shared[((((int)threadIdx.x) + 1152))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1152))]; - B_shared[((((int)threadIdx.x) + 1184))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1184))]; - B_shared[((((int)threadIdx.x) + 1216))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1216))]; - B_shared[((((int)threadIdx.x) + 1248))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1248))]; - B_shared[((((int)threadIdx.x) + 1280))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1280))]; - B_shared[((((int)threadIdx.x) + 1312))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1312))]; - B_shared[((((int)threadIdx.x) + 1344))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1344))]; - B_shared[((((int)threadIdx.x) + 1376))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1376))]; - B_shared[((((int)threadIdx.x) + 1408))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1408))]; - B_shared[((((int)threadIdx.x) + 1440))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1440))]; - B_shared[((((int)threadIdx.x) + 1472))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1472))]; - B_shared[((((int)threadIdx.x) + 1504))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1504))]; - B_shared[((((int)threadIdx.x) + 1536))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1536))]; - B_shared[((((int)threadIdx.x) + 1568))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1568))]; - B_shared[((((int)threadIdx.x) + 1600))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1600))]; - B_shared[((((int)threadIdx.x) + 1632))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1632))]; - B_shared[((((int)threadIdx.x) + 1664))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1664))]; - B_shared[((((int)threadIdx.x) + 1696))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1696))]; - B_shared[((((int)threadIdx.x) + 1728))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1728))]; - B_shared[((((int)threadIdx.x) + 1760))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1760))]; - B_shared[((((int)threadIdx.x) + 1792))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1792))]; - B_shared[((((int)threadIdx.x) + 1824))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1824))]; - B_shared[((((int)threadIdx.x) + 1856))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1856))]; - B_shared[((((int)threadIdx.x) + 1888))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1888))]; - B_shared[((((int)threadIdx.x) + 1920))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1920))]; - B_shared[((((int)threadIdx.x) + 1952))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1952))]; - B_shared[((((int)threadIdx.x) + 1984))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 1984))]; - B_shared[((((int)threadIdx.x) + 2016))] = B[(((((((int)blockIdx.x) * 4096) + (k_outer * 2048)) + ((int)threadIdx.x)) + 2016))]; - __syncthreads(); - for (int k_inner_outer = 0; k_inner_outer < 32; ++k_inner_outer) { - A_shared_local[(0)] = A_shared[(k_inner_outer)]; - B_shared_local[(0)] = B_shared[(((k_inner_outer * 64) + ((int)threadIdx.x)))]; - B_shared_local[(1)] = B_shared[((((k_inner_outer * 64) + ((int)threadIdx.x)) + 32))]; - compute_local[(0)] = (compute_local[(0)] + (A_shared_local[(0)] * B_shared_local[(0)])); - compute_local[(1)] = (compute_local[(1)] + (A_shared_local[(0)] * B_shared_local[(1)])); - } - } - compute[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))] = compute_local[(0)]; - compute[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) + 32))] = compute_local[(1)]; + __shared__ float s_C[256]; + int batch_id = blockIdx.x / 4; + int in_batch_id = blockIdx.x & 3; + int wrap_id = threadIdx.x / 16; + int in_wrap_id = threadIdx.x & 15; + float s = 0; + for (int k = wrap_id; k < 64; k += 16) { + s += A[64 * batch_id + k] * B[64 * 64 * batch_id + 64 * k + in_batch_id * 16 + in_wrap_id]; + } + s_C[threadIdx.x] = s; + __syncthreads(); + if (threadIdx.x < 128) s_C[threadIdx.x] += s_C[threadIdx.x + 128]; __syncthreads(); + if (threadIdx.x < 64) s_C[threadIdx.x] += s_C[threadIdx.x + 64]; __syncthreads(); + if (threadIdx.x < 32) { + s = s_C[threadIdx.x]; + s += s_C[threadIdx.x + 32]; + s += __shfl_xor_sync(0xffffffff, s, 16); + // s += __shfl_xor_sync(0xffffffff, s, 8); + if (wrap_id == 0) { + compute[64 * batch_id + in_batch_id * 16 + in_wrap_id] = s; } - - } - } extern void BatchMatMul_float_float_float_cuda_BatchMatMul_45_Call(const dim3 &grids, const dim3 &blocks, unsigned mem, cudaStream_t stream, float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute) { BatchMatMul_float_float_float_cuda_BatchMatMul_45<<>>(A, B, compute); @@ -639,162 +567,28 @@ extern void Multiply_float_float_float_cuda_Multiply_41_Call(const dim3 &grids, // - name: BatchMatMul_49_0 type: float shape: Shape{1, 12, 1, 64} extern "C" __global__ void BatchMatMul_float_float_float_cuda_BatchMatMul_49(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute) { - __shared__ float A_shared[64]; - __shared__ float B_shared[4096]; - { - { - float compute_local[2]; - - - float A_shared_local[1]; - float B_shared_local[2]; - compute_local[(0)] = 0.000000e+00f; - compute_local[(1)] = 0.000000e+00f; - A_shared[(((int)threadIdx.x))] = A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))]; - A_shared[((((int)threadIdx.x) + 32))] = A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) + 32))]; - B_shared[(((int)threadIdx.x))] = B[(((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)))]; - B_shared[((((int)threadIdx.x) + 32))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 32))]; - B_shared[((((int)threadIdx.x) + 64))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 64))]; - B_shared[((((int)threadIdx.x) + 96))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 96))]; - B_shared[((((int)threadIdx.x) + 128))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 128))]; - B_shared[((((int)threadIdx.x) + 160))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 160))]; - B_shared[((((int)threadIdx.x) + 192))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 192))]; - B_shared[((((int)threadIdx.x) + 224))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 224))]; - B_shared[((((int)threadIdx.x) + 256))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 256))]; - B_shared[((((int)threadIdx.x) + 288))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 288))]; - B_shared[((((int)threadIdx.x) + 320))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 320))]; - B_shared[((((int)threadIdx.x) + 352))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 352))]; - B_shared[((((int)threadIdx.x) + 384))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 384))]; - B_shared[((((int)threadIdx.x) + 416))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 416))]; - B_shared[((((int)threadIdx.x) + 448))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 448))]; - B_shared[((((int)threadIdx.x) + 480))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 480))]; - B_shared[((((int)threadIdx.x) + 512))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 512))]; - B_shared[((((int)threadIdx.x) + 544))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 544))]; - B_shared[((((int)threadIdx.x) + 576))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 576))]; - B_shared[((((int)threadIdx.x) + 608))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 608))]; - B_shared[((((int)threadIdx.x) + 640))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 640))]; - B_shared[((((int)threadIdx.x) + 672))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 672))]; - B_shared[((((int)threadIdx.x) + 704))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 704))]; - B_shared[((((int)threadIdx.x) + 736))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 736))]; - B_shared[((((int)threadIdx.x) + 768))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 768))]; - B_shared[((((int)threadIdx.x) + 800))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 800))]; - B_shared[((((int)threadIdx.x) + 832))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 832))]; - B_shared[((((int)threadIdx.x) + 864))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 864))]; - B_shared[((((int)threadIdx.x) + 896))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 896))]; - B_shared[((((int)threadIdx.x) + 928))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 928))]; - B_shared[((((int)threadIdx.x) + 960))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 960))]; - B_shared[((((int)threadIdx.x) + 992))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 992))]; - B_shared[((((int)threadIdx.x) + 1024))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1024))]; - B_shared[((((int)threadIdx.x) + 1056))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1056))]; - B_shared[((((int)threadIdx.x) + 1088))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1088))]; - B_shared[((((int)threadIdx.x) + 1120))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1120))]; - B_shared[((((int)threadIdx.x) + 1152))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1152))]; - B_shared[((((int)threadIdx.x) + 1184))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1184))]; - B_shared[((((int)threadIdx.x) + 1216))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1216))]; - B_shared[((((int)threadIdx.x) + 1248))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1248))]; - B_shared[((((int)threadIdx.x) + 1280))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1280))]; - B_shared[((((int)threadIdx.x) + 1312))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1312))]; - B_shared[((((int)threadIdx.x) + 1344))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1344))]; - B_shared[((((int)threadIdx.x) + 1376))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1376))]; - B_shared[((((int)threadIdx.x) + 1408))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1408))]; - B_shared[((((int)threadIdx.x) + 1440))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1440))]; - B_shared[((((int)threadIdx.x) + 1472))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1472))]; - B_shared[((((int)threadIdx.x) + 1504))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1504))]; - B_shared[((((int)threadIdx.x) + 1536))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1536))]; - B_shared[((((int)threadIdx.x) + 1568))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1568))]; - B_shared[((((int)threadIdx.x) + 1600))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1600))]; - B_shared[((((int)threadIdx.x) + 1632))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1632))]; - B_shared[((((int)threadIdx.x) + 1664))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1664))]; - B_shared[((((int)threadIdx.x) + 1696))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1696))]; - B_shared[((((int)threadIdx.x) + 1728))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1728))]; - B_shared[((((int)threadIdx.x) + 1760))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1760))]; - B_shared[((((int)threadIdx.x) + 1792))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1792))]; - B_shared[((((int)threadIdx.x) + 1824))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1824))]; - B_shared[((((int)threadIdx.x) + 1856))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1856))]; - B_shared[((((int)threadIdx.x) + 1888))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1888))]; - B_shared[((((int)threadIdx.x) + 1920))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1920))]; - B_shared[((((int)threadIdx.x) + 1952))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1952))]; - B_shared[((((int)threadIdx.x) + 1984))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1984))]; - B_shared[((((int)threadIdx.x) + 2016))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2016))]; - B_shared[((((int)threadIdx.x) + 2048))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2048))]; - B_shared[((((int)threadIdx.x) + 2080))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2080))]; - B_shared[((((int)threadIdx.x) + 2112))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2112))]; - B_shared[((((int)threadIdx.x) + 2144))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2144))]; - B_shared[((((int)threadIdx.x) + 2176))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2176))]; - B_shared[((((int)threadIdx.x) + 2208))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2208))]; - B_shared[((((int)threadIdx.x) + 2240))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2240))]; - B_shared[((((int)threadIdx.x) + 2272))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2272))]; - B_shared[((((int)threadIdx.x) + 2304))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2304))]; - B_shared[((((int)threadIdx.x) + 2336))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2336))]; - B_shared[((((int)threadIdx.x) + 2368))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2368))]; - B_shared[((((int)threadIdx.x) + 2400))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2400))]; - B_shared[((((int)threadIdx.x) + 2432))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2432))]; - B_shared[((((int)threadIdx.x) + 2464))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2464))]; - B_shared[((((int)threadIdx.x) + 2496))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2496))]; - B_shared[((((int)threadIdx.x) + 2528))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2528))]; - B_shared[((((int)threadIdx.x) + 2560))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2560))]; - B_shared[((((int)threadIdx.x) + 2592))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2592))]; - B_shared[((((int)threadIdx.x) + 2624))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2624))]; - B_shared[((((int)threadIdx.x) + 2656))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2656))]; - B_shared[((((int)threadIdx.x) + 2688))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2688))]; - B_shared[((((int)threadIdx.x) + 2720))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2720))]; - B_shared[((((int)threadIdx.x) + 2752))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2752))]; - B_shared[((((int)threadIdx.x) + 2784))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2784))]; - B_shared[((((int)threadIdx.x) + 2816))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2816))]; - B_shared[((((int)threadIdx.x) + 2848))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2848))]; - B_shared[((((int)threadIdx.x) + 2880))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2880))]; - B_shared[((((int)threadIdx.x) + 2912))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2912))]; - B_shared[((((int)threadIdx.x) + 2944))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2944))]; - B_shared[((((int)threadIdx.x) + 2976))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2976))]; - B_shared[((((int)threadIdx.x) + 3008))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3008))]; - B_shared[((((int)threadIdx.x) + 3040))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3040))]; - B_shared[((((int)threadIdx.x) + 3072))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3072))]; - B_shared[((((int)threadIdx.x) + 3104))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3104))]; - B_shared[((((int)threadIdx.x) + 3136))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3136))]; - B_shared[((((int)threadIdx.x) + 3168))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3168))]; - B_shared[((((int)threadIdx.x) + 3200))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3200))]; - B_shared[((((int)threadIdx.x) + 3232))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3232))]; - B_shared[((((int)threadIdx.x) + 3264))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3264))]; - B_shared[((((int)threadIdx.x) + 3296))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3296))]; - B_shared[((((int)threadIdx.x) + 3328))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3328))]; - B_shared[((((int)threadIdx.x) + 3360))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3360))]; - B_shared[((((int)threadIdx.x) + 3392))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3392))]; - B_shared[((((int)threadIdx.x) + 3424))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3424))]; - B_shared[((((int)threadIdx.x) + 3456))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3456))]; - B_shared[((((int)threadIdx.x) + 3488))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3488))]; - B_shared[((((int)threadIdx.x) + 3520))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3520))]; - B_shared[((((int)threadIdx.x) + 3552))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3552))]; - B_shared[((((int)threadIdx.x) + 3584))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3584))]; - B_shared[((((int)threadIdx.x) + 3616))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3616))]; - B_shared[((((int)threadIdx.x) + 3648))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3648))]; - B_shared[((((int)threadIdx.x) + 3680))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3680))]; - B_shared[((((int)threadIdx.x) + 3712))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3712))]; - B_shared[((((int)threadIdx.x) + 3744))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3744))]; - B_shared[((((int)threadIdx.x) + 3776))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3776))]; - B_shared[((((int)threadIdx.x) + 3808))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3808))]; - B_shared[((((int)threadIdx.x) + 3840))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3840))]; - B_shared[((((int)threadIdx.x) + 3872))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3872))]; - B_shared[((((int)threadIdx.x) + 3904))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3904))]; - B_shared[((((int)threadIdx.x) + 3936))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3936))]; - B_shared[((((int)threadIdx.x) + 3968))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3968))]; - B_shared[((((int)threadIdx.x) + 4000))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 4000))]; - B_shared[((((int)threadIdx.x) + 4032))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 4032))]; - B_shared[((((int)threadIdx.x) + 4064))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 4064))]; - __syncthreads(); - for (int k_inner_outer = 0; k_inner_outer < 64; ++k_inner_outer) { - A_shared_local[(0)] = A_shared[(k_inner_outer)]; - B_shared_local[(0)] = B_shared[(((k_inner_outer * 64) + ((int)threadIdx.x)))]; - B_shared_local[(1)] = B_shared[((((k_inner_outer * 64) + ((int)threadIdx.x)) + 32))]; - compute_local[(0)] = (compute_local[(0)] + (A_shared_local[(0)] * B_shared_local[(0)])); - compute_local[(1)] = (compute_local[(1)] + (A_shared_local[(0)] * B_shared_local[(1)])); - } - compute[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))] = compute_local[(0)]; - compute[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) + 32))] = compute_local[(1)]; + __shared__ float s_C[256]; + int batch_id = blockIdx.x / 4; + int in_batch_id = blockIdx.x & 3; + int wrap_id = threadIdx.x / 16; + int in_wrap_id = threadIdx.x & 15; + float s = 0; + for (int k = wrap_id; k < 64; k += 16) { + s += A[64 * batch_id + k] * B[64 * 64 * batch_id + 64 * k + in_batch_id * 16 + in_wrap_id]; + } + s_C[threadIdx.x] = s; + __syncthreads(); + if (threadIdx.x < 128) s_C[threadIdx.x] += s_C[threadIdx.x + 128]; __syncthreads(); + if (threadIdx.x < 64) s_C[threadIdx.x] += s_C[threadIdx.x + 64]; __syncthreads(); + if (threadIdx.x < 32) { + s = s_C[threadIdx.x]; + s += s_C[threadIdx.x + 32]; + s += __shfl_xor_sync(0xffffffff, s, 16); + // s += __shfl_xor_sync(0xffffffff, s, 8); + if (wrap_id == 0) { + compute[64 * batch_id + in_batch_id * 16 + in_wrap_id] = s; } - - } - } extern void BatchMatMul_float_float_float_cuda_BatchMatMul_49_Call(const dim3 &grids, const dim3 &blocks, unsigned mem, cudaStream_t stream, float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute) { BatchMatMul_float_float_float_cuda_BatchMatMul_49<<>>(A, B, compute); @@ -1048,169 +842,35 @@ void Constant_char_cuda_Constant_29(cudaStream_t stream, char* output0) // - name: BatchMatMul_10_0 type: float shape: Shape{1, 12, 1, 64} __device__ __noinline__ void BatchMatMul_float_float_float_cuda_BatchMatMul_10_block_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute, int thread_id, int block_id, char *shared_buffer) { - if (thread_id >= 32){ - for (int i = 0; i < 1; i++) __syncthreads(); + if (thread_id >= 256){ + for (int i = 0; i < 3; i++) __syncthreads(); return; } - const dim3 blockDim(32, 1, 1); - const dim3 gridDim(12, 1, 1); + const dim3 blockDim(256, 1, 1); + const dim3 gridDim(48, 1, 1); const dim3 blockIdx(block_id, 0, 0); - float* A_shared = (float*)(shared_buffer + 0); - float* B_shared = (float*)(shared_buffer + 256); - { - { - float compute_local[2]; - - - float A_shared_local[1]; - float B_shared_local[2]; - compute_local[(0)] = 0.000000e+00f; - compute_local[(1)] = 0.000000e+00f; - A_shared[(((int)threadIdx.x))] = A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))]; - A_shared[((((int)threadIdx.x) + 32))] = A[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) + 32))]; - B_shared[(((int)threadIdx.x))] = B[(((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)))]; - B_shared[((((int)threadIdx.x) + 32))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 32))]; - B_shared[((((int)threadIdx.x) + 64))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 64))]; - B_shared[((((int)threadIdx.x) + 96))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 96))]; - B_shared[((((int)threadIdx.x) + 128))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 128))]; - B_shared[((((int)threadIdx.x) + 160))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 160))]; - B_shared[((((int)threadIdx.x) + 192))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 192))]; - B_shared[((((int)threadIdx.x) + 224))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 224))]; - B_shared[((((int)threadIdx.x) + 256))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 256))]; - B_shared[((((int)threadIdx.x) + 288))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 288))]; - B_shared[((((int)threadIdx.x) + 320))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 320))]; - B_shared[((((int)threadIdx.x) + 352))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 352))]; - B_shared[((((int)threadIdx.x) + 384))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 384))]; - B_shared[((((int)threadIdx.x) + 416))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 416))]; - B_shared[((((int)threadIdx.x) + 448))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 448))]; - B_shared[((((int)threadIdx.x) + 480))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 480))]; - B_shared[((((int)threadIdx.x) + 512))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 512))]; - B_shared[((((int)threadIdx.x) + 544))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 544))]; - B_shared[((((int)threadIdx.x) + 576))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 576))]; - B_shared[((((int)threadIdx.x) + 608))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 608))]; - B_shared[((((int)threadIdx.x) + 640))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 640))]; - B_shared[((((int)threadIdx.x) + 672))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 672))]; - B_shared[((((int)threadIdx.x) + 704))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 704))]; - B_shared[((((int)threadIdx.x) + 736))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 736))]; - B_shared[((((int)threadIdx.x) + 768))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 768))]; - B_shared[((((int)threadIdx.x) + 800))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 800))]; - B_shared[((((int)threadIdx.x) + 832))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 832))]; - B_shared[((((int)threadIdx.x) + 864))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 864))]; - B_shared[((((int)threadIdx.x) + 896))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 896))]; - B_shared[((((int)threadIdx.x) + 928))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 928))]; - B_shared[((((int)threadIdx.x) + 960))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 960))]; - B_shared[((((int)threadIdx.x) + 992))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 992))]; - B_shared[((((int)threadIdx.x) + 1024))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1024))]; - B_shared[((((int)threadIdx.x) + 1056))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1056))]; - B_shared[((((int)threadIdx.x) + 1088))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1088))]; - B_shared[((((int)threadIdx.x) + 1120))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1120))]; - B_shared[((((int)threadIdx.x) + 1152))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1152))]; - B_shared[((((int)threadIdx.x) + 1184))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1184))]; - B_shared[((((int)threadIdx.x) + 1216))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1216))]; - B_shared[((((int)threadIdx.x) + 1248))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1248))]; - B_shared[((((int)threadIdx.x) + 1280))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1280))]; - B_shared[((((int)threadIdx.x) + 1312))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1312))]; - B_shared[((((int)threadIdx.x) + 1344))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1344))]; - B_shared[((((int)threadIdx.x) + 1376))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1376))]; - B_shared[((((int)threadIdx.x) + 1408))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1408))]; - B_shared[((((int)threadIdx.x) + 1440))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1440))]; - B_shared[((((int)threadIdx.x) + 1472))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1472))]; - B_shared[((((int)threadIdx.x) + 1504))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1504))]; - B_shared[((((int)threadIdx.x) + 1536))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1536))]; - B_shared[((((int)threadIdx.x) + 1568))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1568))]; - B_shared[((((int)threadIdx.x) + 1600))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1600))]; - B_shared[((((int)threadIdx.x) + 1632))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1632))]; - B_shared[((((int)threadIdx.x) + 1664))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1664))]; - B_shared[((((int)threadIdx.x) + 1696))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1696))]; - B_shared[((((int)threadIdx.x) + 1728))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1728))]; - B_shared[((((int)threadIdx.x) + 1760))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1760))]; - B_shared[((((int)threadIdx.x) + 1792))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1792))]; - B_shared[((((int)threadIdx.x) + 1824))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1824))]; - B_shared[((((int)threadIdx.x) + 1856))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1856))]; - B_shared[((((int)threadIdx.x) + 1888))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1888))]; - B_shared[((((int)threadIdx.x) + 1920))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1920))]; - B_shared[((((int)threadIdx.x) + 1952))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1952))]; - B_shared[((((int)threadIdx.x) + 1984))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 1984))]; - B_shared[((((int)threadIdx.x) + 2016))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2016))]; - B_shared[((((int)threadIdx.x) + 2048))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2048))]; - B_shared[((((int)threadIdx.x) + 2080))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2080))]; - B_shared[((((int)threadIdx.x) + 2112))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2112))]; - B_shared[((((int)threadIdx.x) + 2144))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2144))]; - B_shared[((((int)threadIdx.x) + 2176))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2176))]; - B_shared[((((int)threadIdx.x) + 2208))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2208))]; - B_shared[((((int)threadIdx.x) + 2240))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2240))]; - B_shared[((((int)threadIdx.x) + 2272))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2272))]; - B_shared[((((int)threadIdx.x) + 2304))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2304))]; - B_shared[((((int)threadIdx.x) + 2336))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2336))]; - B_shared[((((int)threadIdx.x) + 2368))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2368))]; - B_shared[((((int)threadIdx.x) + 2400))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2400))]; - B_shared[((((int)threadIdx.x) + 2432))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2432))]; - B_shared[((((int)threadIdx.x) + 2464))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2464))]; - B_shared[((((int)threadIdx.x) + 2496))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2496))]; - B_shared[((((int)threadIdx.x) + 2528))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2528))]; - B_shared[((((int)threadIdx.x) + 2560))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2560))]; - B_shared[((((int)threadIdx.x) + 2592))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2592))]; - B_shared[((((int)threadIdx.x) + 2624))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2624))]; - B_shared[((((int)threadIdx.x) + 2656))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2656))]; - B_shared[((((int)threadIdx.x) + 2688))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2688))]; - B_shared[((((int)threadIdx.x) + 2720))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2720))]; - B_shared[((((int)threadIdx.x) + 2752))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2752))]; - B_shared[((((int)threadIdx.x) + 2784))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2784))]; - B_shared[((((int)threadIdx.x) + 2816))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2816))]; - B_shared[((((int)threadIdx.x) + 2848))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2848))]; - B_shared[((((int)threadIdx.x) + 2880))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2880))]; - B_shared[((((int)threadIdx.x) + 2912))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2912))]; - B_shared[((((int)threadIdx.x) + 2944))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2944))]; - B_shared[((((int)threadIdx.x) + 2976))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 2976))]; - B_shared[((((int)threadIdx.x) + 3008))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3008))]; - B_shared[((((int)threadIdx.x) + 3040))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3040))]; - B_shared[((((int)threadIdx.x) + 3072))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3072))]; - B_shared[((((int)threadIdx.x) + 3104))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3104))]; - B_shared[((((int)threadIdx.x) + 3136))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3136))]; - B_shared[((((int)threadIdx.x) + 3168))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3168))]; - B_shared[((((int)threadIdx.x) + 3200))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3200))]; - B_shared[((((int)threadIdx.x) + 3232))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3232))]; - B_shared[((((int)threadIdx.x) + 3264))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3264))]; - B_shared[((((int)threadIdx.x) + 3296))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3296))]; - B_shared[((((int)threadIdx.x) + 3328))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3328))]; - B_shared[((((int)threadIdx.x) + 3360))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3360))]; - B_shared[((((int)threadIdx.x) + 3392))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3392))]; - B_shared[((((int)threadIdx.x) + 3424))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3424))]; - B_shared[((((int)threadIdx.x) + 3456))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3456))]; - B_shared[((((int)threadIdx.x) + 3488))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3488))]; - B_shared[((((int)threadIdx.x) + 3520))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3520))]; - B_shared[((((int)threadIdx.x) + 3552))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3552))]; - B_shared[((((int)threadIdx.x) + 3584))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3584))]; - B_shared[((((int)threadIdx.x) + 3616))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3616))]; - B_shared[((((int)threadIdx.x) + 3648))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3648))]; - B_shared[((((int)threadIdx.x) + 3680))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3680))]; - B_shared[((((int)threadIdx.x) + 3712))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3712))]; - B_shared[((((int)threadIdx.x) + 3744))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3744))]; - B_shared[((((int)threadIdx.x) + 3776))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3776))]; - B_shared[((((int)threadIdx.x) + 3808))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3808))]; - B_shared[((((int)threadIdx.x) + 3840))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3840))]; - B_shared[((((int)threadIdx.x) + 3872))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3872))]; - B_shared[((((int)threadIdx.x) + 3904))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3904))]; - B_shared[((((int)threadIdx.x) + 3936))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3936))]; - B_shared[((((int)threadIdx.x) + 3968))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 3968))]; - B_shared[((((int)threadIdx.x) + 4000))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 4000))]; - B_shared[((((int)threadIdx.x) + 4032))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 4032))]; - B_shared[((((int)threadIdx.x) + 4064))] = B[((((((int)blockIdx.x) * 4096) + ((int)threadIdx.x)) + 4064))]; - __syncthreads(); - for (int k_inner_outer = 0; k_inner_outer < 64; ++k_inner_outer) { - A_shared_local[(0)] = A_shared[(k_inner_outer)]; - B_shared_local[(0)] = B_shared[(((k_inner_outer * 64) + ((int)threadIdx.x)))]; - B_shared_local[(1)] = B_shared[((((k_inner_outer * 64) + ((int)threadIdx.x)) + 32))]; - compute_local[(0)] = (compute_local[(0)] + (A_shared_local[(0)] * B_shared_local[(0)])); - compute_local[(1)] = (compute_local[(1)] + (A_shared_local[(0)] * B_shared_local[(1)])); - } - compute[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))] = compute_local[(0)]; - compute[((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) + 32))] = compute_local[(1)]; + float* s_C = (float*)(shared_buffer + 0); + int batch_id = blockIdx.x / 4; + int in_batch_id = blockIdx.x & 3; + int wrap_id = threadIdx.x / 16; + int in_wrap_id = threadIdx.x & 15; + float s = 0; + for (int k = wrap_id; k < 64; k += 16) { + s += A[64 * batch_id + k] * B[64 * 64 * batch_id + 64 * k + in_batch_id * 16 + in_wrap_id]; + } + s_C[threadIdx.x] = s; + __syncthreads(); + if (threadIdx.x < 128) s_C[threadIdx.x] += s_C[threadIdx.x + 128]; __syncthreads(); + if (threadIdx.x < 64) s_C[threadIdx.x] += s_C[threadIdx.x + 64]; __syncthreads(); + if (threadIdx.x < 32) { + s = s_C[threadIdx.x]; + s += s_C[threadIdx.x + 32]; + s += __shfl_xor_sync(0xffffffff, s, 16); + // s += __shfl_xor_sync(0xffffffff, s, 8); + if (wrap_id == 0) { + compute[64 * batch_id + in_batch_id * 16 + in_wrap_id] = s; } - - } - } // Node name: Broadcast_40 // Description: Broadcast @@ -1266,23 +926,23 @@ __device__ __noinline__ void Broadcast_float_float_cuda_Broadcast_40_block_kerne } extern "C" __global__ void BlockFusionKernel_float_float_float_float_float_float_float_float_float_cuda_BatchMatMul_BatchMatMul_BatchMatMul_Broadcast_1(float* input0, float* input1, float* input2, float* input3, float* input4, float* output0, float* output1, float* output2, float* output3) { - __shared__ char shared_buffer[16640]; + __shared__ char shared_buffer[1024]; - if ((int)blockIdx.x >= 0 && (int)blockIdx.x <= 11) + if ((int)blockIdx.x >= 0 && (int)blockIdx.x <= 47) { BatchMatMul_float_float_float_cuda_BatchMatMul_10_block_kernel(input0, input1, output0, threadIdx.x, blockIdx.x - 0 + 0, shared_buffer); } - else if ((int)blockIdx.x >= 12 && (int)blockIdx.x <= 23) + else if ((int)blockIdx.x >= 48 && (int)blockIdx.x <= 95) { - BatchMatMul_float_float_float_cuda_BatchMatMul_10_block_kernel(input0, input2, output1, threadIdx.x, blockIdx.x - 12 + 0, shared_buffer); + BatchMatMul_float_float_float_cuda_BatchMatMul_10_block_kernel(input0, input2, output1, threadIdx.x, blockIdx.x - 48 + 0, shared_buffer); } - else if ((int)blockIdx.x >= 24 && (int)blockIdx.x <= 35) + else if ((int)blockIdx.x >= 96 && (int)blockIdx.x <= 143) { - BatchMatMul_float_float_float_cuda_BatchMatMul_10_block_kernel(input0, input3, output2, threadIdx.x, blockIdx.x - 24 + 0, shared_buffer); + BatchMatMul_float_float_float_cuda_BatchMatMul_10_block_kernel(input0, input3, output2, threadIdx.x, blockIdx.x - 96 + 0, shared_buffer); } - else if ((int)blockIdx.x >= 36 && (int)blockIdx.x <= 47) + else if ((int)blockIdx.x >= 144 && (int)blockIdx.x <= 155) { - Broadcast_float_float_cuda_Broadcast_40_block_kernel(input4, output3, threadIdx.x, blockIdx.x - 36 + 0, shared_buffer); + Broadcast_float_float_cuda_Broadcast_40_block_kernel(input4, output3, threadIdx.x, blockIdx.x - 144 + 0, shared_buffer); } } @@ -1453,7 +1113,7 @@ Reshape_float_float_cuda_Reshape_22_Call(dim3(12, 1, 1), dim3(64, 1, 1), 0, 0, P // name=Reshape_9 // eliminated: Reshape_float_float_cuda_Reshape_9_Call(dim3(768, 1, 1), dim3(64, 1, 1), 0, 0, Constant_0_0, Reshape_9_0); // name=blockfusion_kernel_58 -BlockFusionKernel_float_float_float_float_float_float_float_float_float_cuda_BatchMatMul_BatchMatMul_BatchMatMul_Broadcast_1_Call(dim3(48, 1, 1), dim3(64, 1, 1), 0, 0, Reshape_22_0, Reshape_9_0, Reshape_13_0, Reshape_23_0, Reshape_39_0, BatchMatMul_10_0, BatchMatMul_14_0, BatchMatMul_24_0, Broadcast_40_0); +BlockFusionKernel_float_float_float_float_float_float_float_float_float_cuda_BatchMatMul_BatchMatMul_BatchMatMul_Broadcast_1_Call(dim3(156, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_22_0, Reshape_9_0, Reshape_13_0, Reshape_23_0, Reshape_39_0, BatchMatMul_10_0, BatchMatMul_14_0, BatchMatMul_24_0, Broadcast_40_0); // name=Reshape_25 // eliminated: Reshape_float_float_cuda_Reshape_22_Call(dim3(12, 1, 1), dim3(64, 1, 1), 0, 0, BatchMatMul_24_0, Reshape_25_0); // name=@tmp_9 @@ -1496,14 +1156,14 @@ Softmax_float_float_cuda_Softmax_42_Call(dim3(6, 1, 1), dim3(64, 1, 1), 0, 0, Mu // name=Reshape_43 // eliminated: Reshape_float_float_cuda_Reshape_22_Call(dim3(12, 1, 1), dim3(64, 1, 1), 0, 0, Softmax_42_0, Reshape_43_0); // name=x_1 -BatchMatMul_float_float_float_cuda_BatchMatMul_45_Call(dim3(12, 1, 1), dim3(32, 1, 1), 0, 0, Reshape_43_0, Parameter_7_0, BatchMatMul_45_0); +BatchMatMul_float_float_float_cuda_BatchMatMul_45_Call(dim3(48, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_43_0, Parameter_7_0, BatchMatMul_45_0); // DEBUG_TENSOR(Reshape_43_0, 1); // name=Reshape_46 // eliminated: Reshape_float_float_cuda_Reshape_22_Call(dim3(12, 1, 1), dim3(64, 1, 1), 0, 0, BatchMatMul_45_0, Reshape_46_0); // name=Reshape_47 // eliminated: Reshape_float_float_cuda_Reshape_22_Call(dim3(12, 1, 1), dim3(64, 1, 1), 0, 0, Reshape_46_0, Reshape_47_0); // name=x_2 -BatchMatMul_float_float_float_cuda_BatchMatMul_49_Call(dim3(12, 1, 1), dim3(32, 1, 1), 0, 0, Reshape_47_0, Reshape_48_0, BatchMatMul_49_0); +BatchMatMul_float_float_float_cuda_BatchMatMul_49_Call(dim3(48, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_47_0, Reshape_48_0, BatchMatMul_49_0); // name=Reshape_50 Reshape_float_float_cuda_Reshape_22_Call(dim3(12, 1, 1), dim3(64, 1, 1), 0, 0, BatchMatMul_49_0, Result_55_0); // name=Result_55 diff --git a/artifacts/models/manual_attention/bs1/run.py b/artifacts/models/manual_attention/bs1/run.py index aa081b3ea..b3e3694e1 100644 --- a/artifacts/models/manual_attention/bs1/run.py +++ b/artifacts/models/manual_attention/bs1/run.py @@ -28,6 +28,7 @@ def forward(self, x, k, v): SIZE_PER_HEAD = 64 n_warmup = 100 n_run = 100 +torch.random.manual_seed(0) if __name__ == "__main__": batch_size = 1 @@ -41,6 +42,7 @@ def forward(self, x, k, v): for i in range(n_warmup): _ = model.forward(x, k, v) torch.cuda.synchronize() + print(_) # run timer = Timer("ms") enable_profile('V100') diff --git a/artifacts/models/manual_seq2seq/bs1/run.py b/artifacts/models/manual_seq2seq/bs1/run.py index 71535f593..9c4a6bff3 100644 --- a/artifacts/models/manual_seq2seq/bs1/run.py +++ b/artifacts/models/manual_seq2seq/bs1/run.py @@ -183,7 +183,7 @@ def read_bin(s, dtype=np.float32): mask = masks[:, i:i+batch_size].contiguous() torch.cuda.synchronize() output_all, h = model(encoder_output, mask, h, c) - if i == 0: print(output_all) + if i == 0: print(output_all, h) torch.cuda.synchronize() # run timer = Timer("ms") diff --git a/artifacts/models/manual_seq2seq/bs1/seq2seq_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu b/artifacts/models/manual_seq2seq/bs1/seq2seq_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu index 0550fd9dc..857cc3b54 100644 --- a/artifacts/models/manual_seq2seq/bs1/seq2seq_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu +++ b/artifacts/models/manual_seq2seq/bs1/seq2seq_bs1_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu @@ -654,36 +654,30 @@ extern void FusedKernel_float_float_float_float_float_float_float_float_float_fl // - name: Dot_77_0 type: float shape: Shape{1, 3797} extern "C" __global__ void Dot_float_float_float_cuda_Dot_77(float* __restrict__ input0, float* __restrict__ input1, float* __restrict__ output0) { + const int num_tasks = 3797; + int block_start = 16 * 256 * blockIdx.x; + int warp_id = threadIdx.x >> 5; + int lane_id = threadIdx.x & 31; + int task_id_in_block = threadIdx.x >> 4; + int in_task_id = threadIdx.x & 15; __shared__ float share_a[256]; - __shared__ float share_b[2304]; - { - { - const int num_tasks = 3797; - int block_start = 64 * 256 * blockIdx.x; - int warp_id = threadIdx.x >> 5; - int lane_id = threadIdx.x & 31; - int task_id_in_block = threadIdx.x >> 2; - int in_task_id = threadIdx.x & 3; - - // 36 * 64 - share_a[threadIdx.x] = input0[threadIdx.x]; - float s = 0.0; - for (int k = 0; k < 256; k += 32) { - #pragma unroll - for (int i = warp_id; i < 64; i += 8) if (blockIdx.x * 64 + i < num_tasks) share_b[i * 36 + lane_id] = input1[block_start + i * 256 + k + lane_id]; - __syncthreads(); - // if (threadIdx.x == 0) { printf("shareb k=%d:", k); for (int i = 0; i < 32; i++) printf("%f ", share_b[i]); printf("\n");} - #pragma unroll - for (int j = in_task_id; j < 32; j += 4) s += share_a[k + j] * share_b[task_id_in_block * 36 + j]; - __syncthreads(); - } - s += __shfl_xor_sync(0xffffffff, s, 2); - s += __shfl_xor_sync(0xffffffff, s, 1); - if (in_task_id == 0 && blockIdx.x * 64 + task_id_in_block < num_tasks) output0[blockIdx.x * 64 + task_id_in_block] = s; - } - + __shared__ float share_b[768]; // 36 * 64 + share_a[threadIdx.x] = input0[threadIdx.x]; + float s = 0.0; + for (int k = 0; k < 256; k += 32) { + #pragma unroll + for (int i = warp_id; i < 16; i += 8) if (blockIdx.x * 16 + i < num_tasks) share_b[i * 48 + lane_id] = input1[block_start + i * 256 + k + lane_id]; + __syncthreads(); + // if (threadIdx.x == 0) { printf("shareb k=%d:", k); for (int i = 0; i < 32; i++) printf("%f ", share_b[i]); printf("\n");} + #pragma unroll + for (int j = in_task_id; j < 32; j += 16) s += share_a[k + j] * share_b[task_id_in_block * 48 + j]; + __syncthreads(); } - + s += __shfl_xor_sync(0xffffffff, s, 8); + s += __shfl_xor_sync(0xffffffff, s, 4); + s += __shfl_xor_sync(0xffffffff, s, 2); + s += __shfl_xor_sync(0xffffffff, s, 1); + if (in_task_id == 0 && blockIdx.x * 16 + task_id_in_block < num_tasks) output0[blockIdx.x * 16 + task_id_in_block] = s; } extern void Dot_float_float_float_cuda_Dot_77_Call(const dim3 &grids, const dim3 &blocks, unsigned mem, cudaStream_t stream, float* __restrict__ input0, float* __restrict__ input1, float* __restrict__ output0) { Dot_float_float_float_cuda_Dot_77<<>>(input0, input1, output0); @@ -1546,7 +1540,7 @@ FusedKernel_float_float_float_float_float_float_float_float_float_float_float_fl // name=@tmp_33 Reshape_float_float_cuda_Reshape_21_Call(dim3(1, 1, 1), dim3(256, 1, 1), 0, 0, Result_96_0, Reshape_76_0); // name=Dot_77 -Dot_float_float_float_cuda_Dot_77_Call(dim3(64, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_76_0, Constant_11_0, Dot_77_0); +Dot_float_float_float_cuda_Dot_77_Call(dim3(238, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_76_0, Constant_11_0, Dot_77_0); // name=ElementWiseFused_103 FusedKernel_float_float_float_float_cuda_Add_Reshape_Add_1_Call(dim3(15, 1, 1), dim3(256, 1, 1), 0, 0, Dot_77_0, Reshape_78_0, GatherV2_82_0, Add_83_0); // name=output_2 diff --git a/artifacts/models/manual_seq2seq/bs64/run.py b/artifacts/models/manual_seq2seq/bs64/run.py index d238f8b4e..c86dca5cb 100644 --- a/artifacts/models/manual_seq2seq/bs64/run.py +++ b/artifacts/models/manual_seq2seq/bs64/run.py @@ -184,7 +184,7 @@ def read_bin(s, dtype=np.float32): mask = masks[:, i:i+batch_size].contiguous() torch.cuda.synchronize() output_all, h = model(encoder_output, mask, h, c) - if i == 0: print(output_all) + if i == 0: print(output_all, h) torch.cuda.synchronize() # run timer = Timer("ms") diff --git a/artifacts/models/manual_seq2seq/bs64/seq2seq_bs64_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu b/artifacts/models/manual_seq2seq/bs64/seq2seq_bs64_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu index d885787bd..8263f4465 100644 --- a/artifacts/models/manual_seq2seq/bs64/seq2seq_bs64_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu +++ b/artifacts/models/manual_seq2seq/bs64/seq2seq_bs64_0-forward/nnfusion_rt/cuda_codegen/nnfusion_rt.cu @@ -602,176 +602,60 @@ void Result_int64_t_int64_t_cuda_lib_Result_101(cudaStream_t stream, int64_t* in __device__ __noinline__ void BatchMatMul_float_float_float_cuda_BatchMatMul_28_block_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute, int thread_id, int block_id, char *shared_buffer) { if (thread_id >= 256){ - for (int i = 0; i < 16; i++) __syncthreads(); + for (int i = 0; i < 4; i++) __syncthreads(); return; } const dim3 blockDim(256, 1, 1); - const dim3 gridDim(8, 1, 1); + const dim3 gridDim(64, 1, 1); const dim3 blockIdx(block_id, 0, 0); float* A_shared = (float*)(shared_buffer + 0); - float* B_shared = (float*)(shared_buffer + 8192); + float* B_shared = (float*)(shared_buffer + 16384); { { - float compute_local[32]; - - - float A_shared_local[4]; - float B_shared_local[8]; - compute_local[(0)] = 0.000000e+00f; - compute_local[(8)] = 0.000000e+00f; - compute_local[(16)] = 0.000000e+00f; - compute_local[(24)] = 0.000000e+00f; - compute_local[(2)] = 0.000000e+00f; - compute_local[(10)] = 0.000000e+00f; - compute_local[(18)] = 0.000000e+00f; - compute_local[(26)] = 0.000000e+00f; - compute_local[(4)] = 0.000000e+00f; - compute_local[(12)] = 0.000000e+00f; - compute_local[(20)] = 0.000000e+00f; - compute_local[(28)] = 0.000000e+00f; - compute_local[(6)] = 0.000000e+00f; - compute_local[(14)] = 0.000000e+00f; - compute_local[(22)] = 0.000000e+00f; - compute_local[(30)] = 0.000000e+00f; - compute_local[(1)] = 0.000000e+00f; - compute_local[(9)] = 0.000000e+00f; - compute_local[(17)] = 0.000000e+00f; - compute_local[(25)] = 0.000000e+00f; - compute_local[(3)] = 0.000000e+00f; - compute_local[(11)] = 0.000000e+00f; - compute_local[(19)] = 0.000000e+00f; - compute_local[(27)] = 0.000000e+00f; - compute_local[(5)] = 0.000000e+00f; - compute_local[(13)] = 0.000000e+00f; - compute_local[(21)] = 0.000000e+00f; - compute_local[(29)] = 0.000000e+00f; - compute_local[(7)] = 0.000000e+00f; - compute_local[(15)] = 0.000000e+00f; - compute_local[(23)] = 0.000000e+00f; - compute_local[(31)] = 0.000000e+00f; - for (int k_outer = 0; k_outer < 8; ++k_outer) { - __syncthreads(); - A_shared[(((int)threadIdx.x))] = A[(((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)))]; - A_shared[((((int)threadIdx.x) + 256))] = A[((((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)) + 2048))]; - A_shared[((((int)threadIdx.x) + 512))] = A[((((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)) + 4096))]; - A_shared[((((int)threadIdx.x) + 768))] = A[((((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)) + 6144))]; - A_shared[((((int)threadIdx.x) + 1024))] = A[((((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)) + 8192))]; - A_shared[((((int)threadIdx.x) + 1280))] = A[((((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)) + 10240))]; - A_shared[((((int)threadIdx.x) + 1536))] = A[((((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)) + 12288))]; - A_shared[((((int)threadIdx.x) + 1792))] = A[((((((((int)threadIdx.x) >> 5) * 256) + (k_outer * 32)) + (((int)threadIdx.x) & 31)) + 14336))]; - B_shared[(((int)threadIdx.x))] = B[(((((k_outer * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 256))] = B[((((((k_outer * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)) + 2048))]; - B_shared[((((int)threadIdx.x) + 512))] = B[((((((k_outer * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)) + 4096))]; - B_shared[((((int)threadIdx.x) + 768))] = B[((((((k_outer * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)) + 6144))]; - B_shared[((((int)threadIdx.x) + 1024))] = B[((((((k_outer * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)) + 65536))]; - B_shared[((((int)threadIdx.x) + 1280))] = B[((((((((((int)threadIdx.x) + 1280) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 8) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 1536))] = B[((((((((((int)threadIdx.x) + 1536) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 16) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 1792))] = B[((((((((((int)threadIdx.x) + 1792) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 24) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 2048))] = B[((((((k_outer * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)) + 131072))]; - B_shared[((((int)threadIdx.x) + 2304))] = B[((((((((((int)threadIdx.x) + 2304) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 8) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 2560))] = B[((((((((((int)threadIdx.x) + 2560) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 16) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 2816))] = B[((((((((((int)threadIdx.x) + 2816) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 24) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 3072))] = B[((((((k_outer * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)) + 196608))]; - B_shared[((((int)threadIdx.x) + 3328))] = B[((((((((((int)threadIdx.x) + 3328) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 8) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 3584))] = B[((((((((((int)threadIdx.x) + 3584) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 16) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - B_shared[((((int)threadIdx.x) + 3840))] = B[((((((((((int)threadIdx.x) + 3840) >> 10) * 65536) + (k_outer * 8192)) + (((((int)threadIdx.x) >> 5) + 24) * 256)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 31)))]; - __syncthreads(); - for (int k_inner_outer = 0; k_inner_outer < 32; ++k_inner_outer) { - A_shared_local[(0)] = A_shared[((((((int)threadIdx.x) >> 4) * 32) + k_inner_outer))]; - A_shared_local[(1)] = A_shared[(((((((int)threadIdx.x) >> 4) * 32) + k_inner_outer) + 512))]; - A_shared_local[(2)] = A_shared[(((((((int)threadIdx.x) >> 4) * 32) + k_inner_outer) + 1024))]; - A_shared_local[(3)] = A_shared[(((((((int)threadIdx.x) >> 4) * 32) + k_inner_outer) + 1536))]; - B_shared_local[(0)] = B_shared[(((k_inner_outer * 32) + (((int)threadIdx.x) & 15)))]; - B_shared_local[(2)] = B_shared[((((k_inner_outer * 32) + (((int)threadIdx.x) & 15)) + 1024))]; - B_shared_local[(4)] = B_shared[((((k_inner_outer * 32) + (((int)threadIdx.x) & 15)) + 2048))]; - B_shared_local[(6)] = B_shared[((((k_inner_outer * 32) + (((int)threadIdx.x) & 15)) + 3072))]; - B_shared_local[(1)] = B_shared[((((k_inner_outer * 32) + (((int)threadIdx.x) & 15)) + 16))]; - B_shared_local[(3)] = B_shared[((((k_inner_outer * 32) + (((int)threadIdx.x) & 15)) + 1040))]; - B_shared_local[(5)] = B_shared[((((k_inner_outer * 32) + (((int)threadIdx.x) & 15)) + 2064))]; - B_shared_local[(7)] = B_shared[((((k_inner_outer * 32) + (((int)threadIdx.x) & 15)) + 3088))]; - compute_local[(0)] = (compute_local[(0)] + (A_shared_local[(0)] * B_shared_local[(0)])); - compute_local[(8)] = (compute_local[(8)] + (A_shared_local[(0)] * B_shared_local[(2)])); - compute_local[(16)] = (compute_local[(16)] + (A_shared_local[(0)] * B_shared_local[(4)])); - compute_local[(24)] = (compute_local[(24)] + (A_shared_local[(0)] * B_shared_local[(6)])); - compute_local[(2)] = (compute_local[(2)] + (A_shared_local[(1)] * B_shared_local[(0)])); - compute_local[(10)] = (compute_local[(10)] + (A_shared_local[(1)] * B_shared_local[(2)])); - compute_local[(18)] = (compute_local[(18)] + (A_shared_local[(1)] * B_shared_local[(4)])); - compute_local[(26)] = (compute_local[(26)] + (A_shared_local[(1)] * B_shared_local[(6)])); - compute_local[(4)] = (compute_local[(4)] + (A_shared_local[(2)] * B_shared_local[(0)])); - compute_local[(12)] = (compute_local[(12)] + (A_shared_local[(2)] * B_shared_local[(2)])); - compute_local[(20)] = (compute_local[(20)] + (A_shared_local[(2)] * B_shared_local[(4)])); - compute_local[(28)] = (compute_local[(28)] + (A_shared_local[(2)] * B_shared_local[(6)])); - compute_local[(6)] = (compute_local[(6)] + (A_shared_local[(3)] * B_shared_local[(0)])); - compute_local[(14)] = (compute_local[(14)] + (A_shared_local[(3)] * B_shared_local[(2)])); - compute_local[(22)] = (compute_local[(22)] + (A_shared_local[(3)] * B_shared_local[(4)])); - compute_local[(30)] = (compute_local[(30)] + (A_shared_local[(3)] * B_shared_local[(6)])); - compute_local[(1)] = (compute_local[(1)] + (A_shared_local[(0)] * B_shared_local[(1)])); - compute_local[(9)] = (compute_local[(9)] + (A_shared_local[(0)] * B_shared_local[(3)])); - compute_local[(17)] = (compute_local[(17)] + (A_shared_local[(0)] * B_shared_local[(5)])); - compute_local[(25)] = (compute_local[(25)] + (A_shared_local[(0)] * B_shared_local[(7)])); - compute_local[(3)] = (compute_local[(3)] + (A_shared_local[(1)] * B_shared_local[(1)])); - compute_local[(11)] = (compute_local[(11)] + (A_shared_local[(1)] * B_shared_local[(3)])); - compute_local[(19)] = (compute_local[(19)] + (A_shared_local[(1)] * B_shared_local[(5)])); - compute_local[(27)] = (compute_local[(27)] + (A_shared_local[(1)] * B_shared_local[(7)])); - compute_local[(5)] = (compute_local[(5)] + (A_shared_local[(2)] * B_shared_local[(1)])); - compute_local[(13)] = (compute_local[(13)] + (A_shared_local[(2)] * B_shared_local[(3)])); - compute_local[(21)] = (compute_local[(21)] + (A_shared_local[(2)] * B_shared_local[(5)])); - compute_local[(29)] = (compute_local[(29)] + (A_shared_local[(2)] * B_shared_local[(7)])); - compute_local[(7)] = (compute_local[(7)] + (A_shared_local[(3)] * B_shared_local[(1)])); - compute_local[(15)] = (compute_local[(15)] + (A_shared_local[(3)] * B_shared_local[(3)])); - compute_local[(23)] = (compute_local[(23)] + (A_shared_local[(3)] * B_shared_local[(5)])); - compute_local[(31)] = (compute_local[(31)] + (A_shared_local[(3)] * B_shared_local[(7)])); + float compute_local[4]; + compute_local[(0)] = 0.000000e+00f; + compute_local[(2)] = 0.000000e+00f; + compute_local[(1)] = 0.000000e+00f; + compute_local[(3)] = 0.000000e+00f; + for (int k_outer_outer = 0; k_outer_outer < 2; ++k_outer_outer) { + __syncthreads(); + ((float4*)(A_shared + ((((int)threadIdx.x) * 4))))[0] = ((float4*)(A + (((((((((int)blockIdx.x) & 31) >> 4) * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (k_outer_outer * 128)) + ((((int)threadIdx.x) & 31) * 4)))))[0]; + ((float4*)(A_shared + (((((int)threadIdx.x) * 4) + 1024))))[0] = ((float4*)(A + ((((((((((int)blockIdx.x) & 31) >> 4) * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (k_outer_outer * 128)) + ((((int)threadIdx.x) & 31) * 4)) + 2048))))[0]; + ((float4*)(A_shared + (((((int)threadIdx.x) * 4) + 2048))))[0] = ((float4*)(A + ((((((((((int)blockIdx.x) & 31) >> 4) * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (k_outer_outer * 128)) + ((((int)threadIdx.x) & 31) * 4)) + 4096))))[0]; + ((float4*)(A_shared + (((((int)threadIdx.x) * 4) + 3072))))[0] = ((float4*)(A + ((((((((((int)blockIdx.x) & 31) >> 4) * 8192) + ((((int)threadIdx.x) >> 5) * 256)) + (k_outer_outer * 128)) + ((((int)threadIdx.x) & 31) * 4)) + 6144))))[0]; + ((float4*)(B_shared + ((((int)threadIdx.x) * 4))))[0] = ((float4*)(B + (((((((((int)blockIdx.x) >> 5) * 131072) + (k_outer_outer * 32768)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)blockIdx.x) & 15) * 16)) + ((((int)threadIdx.x) & 3) * 4)))))[0]; + ((float4*)(B_shared + (((((int)threadIdx.x) * 4) + 1024))))[0] = ((float4*)(B + ((((((((((int)blockIdx.x) >> 5) * 131072) + (k_outer_outer * 32768)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)blockIdx.x) & 15) * 16)) + ((((int)threadIdx.x) & 3) * 4)) + 16384))))[0]; + ((float4*)(B_shared + (((((int)threadIdx.x) * 4) + 2048))))[0] = ((float4*)(B + ((((((((((int)blockIdx.x) >> 5) * 131072) + (k_outer_outer * 32768)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)blockIdx.x) & 15) * 16)) + ((((int)threadIdx.x) & 3) * 4)) + 65536))))[0]; + ((float4*)(B_shared + (((((((((int)threadIdx.x) * 4) + 3072) >> 11) * 2048) + (((((int)threadIdx.x) >> 2) + 64) * 16)) + ((((int)threadIdx.x) & 3) * 4)))))[0] = ((float4*)(B + ((((((((((int)blockIdx.x) >> 5) * 131072) + ((((((int)threadIdx.x) * 4) + 3072) >> 11) * 65536)) + (k_outer_outer * 32768)) + (((((int)threadIdx.x) >> 2) + 64) * 256)) + ((((int)blockIdx.x) & 15) * 16)) + ((((int)threadIdx.x) & 3) * 4)))))[0]; + __syncthreads(); + for (int k_outer_inner = 0; k_outer_inner < 2; ++k_outer_inner) { + for (int k_inner = 0; k_inner < 64; ++k_inner) { + compute_local[(0)] = (compute_local[(0)] + (A_shared[(((((((int)threadIdx.x) >> 4) * 256) + (k_outer_inner * 64)) + k_inner))] * B_shared[((((k_outer_inner * 1024) + (k_inner * 16)) + (((int)threadIdx.x) & 15)))])); + compute_local[(2)] = (compute_local[(2)] + (A_shared[(((((((int)threadIdx.x) >> 4) * 256) + (k_outer_inner * 64)) + k_inner))] * B_shared[(((((k_outer_inner * 1024) + (k_inner * 16)) + (((int)threadIdx.x) & 15)) + 2048))])); + compute_local[(1)] = (compute_local[(1)] + (A_shared[((((((((int)threadIdx.x) >> 4) * 256) + (k_outer_inner * 64)) + k_inner) + 128))] * B_shared[((((k_outer_inner * 1024) + (k_inner * 16)) + (((int)threadIdx.x) & 15)))])); + compute_local[(3)] = (compute_local[(3)] + (A_shared[((((((((int)threadIdx.x) >> 4) * 256) + (k_outer_inner * 64)) + k_inner) + 128))] * B_shared[(((((k_outer_inner * 1024) + (k_inner * 16)) + (((int)threadIdx.x) & 15)) + 2048))])); + } + } + } + for (int x_inner = 0; x_inner < 2; ++x_inner) { + compute[((((((((((int)blockIdx.x) >> 5) * 32768) + (((((int)blockIdx.x) & 31) >> 4) * 8192)) + ((((int)threadIdx.x) >> 4) * 512)) + (x_inner * 256)) + ((((int)blockIdx.x) & 15) * 16)) + (((int)threadIdx.x) & 15)))] = compute_local[(x_inner)]; + compute[(((((((((((int)blockIdx.x) >> 5) * 32768) + (((((int)blockIdx.x) & 31) >> 4) * 8192)) + ((((int)threadIdx.x) >> 4) * 512)) + (x_inner * 256)) + ((((int)blockIdx.x) & 15) * 16)) + (((int)threadIdx.x) & 15)) + 16384))] = compute_local[((x_inner + 2))]; } - } - compute[(((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)))] = compute_local[(0)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 16384))] = compute_local[(8)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 32768))] = compute_local[(16)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 49152))] = compute_local[(24)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 4096))] = compute_local[(2)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 20480))] = compute_local[(10)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 36864))] = compute_local[(18)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 53248))] = compute_local[(26)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 8192))] = compute_local[(4)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 24576))] = compute_local[(12)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 40960))] = compute_local[(20)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 57344))] = compute_local[(28)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 12288))] = compute_local[(6)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 28672))] = compute_local[(14)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 45056))] = compute_local[(22)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 61440))] = compute_local[(30)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 16))] = compute_local[(1)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 16400))] = compute_local[(9)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 32784))] = compute_local[(17)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 49168))] = compute_local[(25)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 4112))] = compute_local[(3)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 20496))] = compute_local[(11)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 36880))] = compute_local[(19)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 53264))] = compute_local[(27)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 8208))] = compute_local[(5)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 24592))] = compute_local[(13)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 40976))] = compute_local[(21)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 57360))] = compute_local[(29)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 12304))] = compute_local[(7)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 28688))] = compute_local[(15)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 45072))] = compute_local[(23)]; - compute[((((((((int)threadIdx.x) >> 4) * 256) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) & 15)) + 61456))] = compute_local[(31)]; } - - } } extern "C" __global__ void BlockFusionKernel_float_float_float_float_float_float_cuda_BatchMatMul_BatchMatMul_1(float* input0, float* input1, float* input2, float* input3, float* output0, float* output1) { - __shared__ char shared_buffer[24576]; + __shared__ char shared_buffer[32768]; - if ((int)blockIdx.x >= 0 && (int)blockIdx.x <= 7) + if ((int)blockIdx.x >= 0 && (int)blockIdx.x <= 63) { BatchMatMul_float_float_float_cuda_BatchMatMul_28_block_kernel(input0, input1, output0, threadIdx.x, blockIdx.x - 0 + 0, shared_buffer); } - else if ((int)blockIdx.x >= 8 && (int)blockIdx.x <= 15) + else if ((int)blockIdx.x >= 64 && (int)blockIdx.x <= 127) { - BatchMatMul_float_float_float_cuda_BatchMatMul_28_block_kernel(input2, input3, output1, threadIdx.x, blockIdx.x - 8 + 0, shared_buffer); + BatchMatMul_float_float_float_cuda_BatchMatMul_28_block_kernel(input2, input3, output1, threadIdx.x, blockIdx.x - 64 + 0, shared_buffer); } } @@ -1817,7 +1701,7 @@ Reshape_float_float_cuda_Reshape_21_Call(dim3(64, 1, 1), dim3(256, 1, 1), 0, 0, // name=Reshape_26 // eliminated: Reshape_float_float_cuda_Reshape_21_Call(dim3(64, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_21_0, Reshape_26_0); // name=blockfusion_kernel_105 -BlockFusionKernel_float_float_float_float_float_float_cuda_BatchMatMul_BatchMatMul_1_Call(dim3(16, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_26_0, Reshape_27_0, Reshape_22_0, Reshape_23_0, BatchMatMul_28_0, BatchMatMul_24_0); +BlockFusionKernel_float_float_float_float_float_float_cuda_BatchMatMul_BatchMatMul_1_Call(dim3(128, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_26_0, Reshape_27_0, Reshape_22_0, Reshape_23_0, BatchMatMul_28_0, BatchMatMul_24_0); // name=Reshape_25 // eliminated: Reshape_float_float_cuda_Reshape_25_Call(dim3(256, 1, 1), dim3(256, 1, 1), 0, 0, BatchMatMul_24_0, Reshape_25_0); // name=Reshape_29