Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TL] Add example usage/test case for Dynamic Symbolic #191

Merged
merged 67 commits into from
Sep 26, 2024

Conversation

LeiWang1999
Copy link
Contributor

New Test File:

  • Added a new test file testing/python/tilelang/test_tilelang_dyanmic_symbolic.py to test dynamic symbolic operations using the TensorCoreIntrinEmitter class.

Example Usage

Describe Computation with Dynamic Symbolic.

M = tvm.te.var("m")
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)

import tvm.tl.language as T

@T.prim_func
def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N),
                                                                                    dtypeC)):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
        B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K], B_shared)
            T.gemm(A_shared, B_shared, C_local, trans_B=True)
        T.copy(C_local, C[by * block_M, bx * block_N])

Below is the generated kernel :

#include <tl_templates/gemm.h>
#include <tl_templates/copy.h>
#include <tl_templates/reduce.h>
#include <tl_templates/ldsm.h>
#include <tl_templates/threadblock_swizzle.h>

extern "C" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C, int m) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  half_t C_local[32];
  #pragma unroll
  for (int i = 0; i < 16; ++i) {
    *(uint1*)(C_local + (i * 2)) = make_uint1(__pack_half2(half_t(0.000000e+00f), half_t(0.000000e+00f)));
  }
  #pragma unroll
  for (int i_1 = 0; i_1 < 2; ++i_1) {
    tl::cp_async_gs_conditional<16>(buf_dyn_shmem+((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((int)blockIdx.y) * 8192) + (i_1 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 8)), ((((((int)blockIdx.y) * 64) + (i_1 * 32)) + (((int)threadIdx.x) >> 2)) < m));
  }
  #pragma unroll
  for (int i_2 = 0; i_2 < 2; ++i_2) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_2 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 12288), B+((((i_2 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)));
  }
  tl::cp_async_commit();
  #pragma unroll
  for (int i_3 = 0; i_3 < 2; ++i_3) {
    tl::cp_async_gs_conditional<16>(buf_dyn_shmem+(((((i_3 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 4096), A+(((((((int)blockIdx.y) * 8192) + (i_3 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 8)) + 32), ((((((int)blockIdx.y) * 64) + (i_3 * 32)) + (((int)threadIdx.x) >> 2)) < m));
  }
  #pragma unroll
  for (int i_4 = 0; i_4 < 2; ++i_4) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_4 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384), B+(((((i_4 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 4096));
  }
  tl::cp_async_commit();
  for (int k = 0; k < 2; ++k) {
    __syncthreads();
    #pragma unroll
    for (int i_5 = 0; i_5 < 2; ++i_5) {
      tl::cp_async_gs_conditional<16>(buf_dyn_shmem+(((((((k + 2) % 3) * 4096) + (i_5 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((((int)blockIdx.y) * 8192) + (i_5 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + (k * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 64), ((((((int)blockIdx.y) * 64) + (i_5 * 32)) + (((int)threadIdx.x) >> 2)) < m));
    }
    #pragma unroll
    for (int i_6 = 0; i_6 < 2; ++i_6) {
      tl::cp_async_gs<16>(buf_dyn_shmem+(((((((((k + 2) % 3) * 4096) + (i_6 * 2048)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 12288), B+((((((k * 4096) + (i_6 * 2048)) + ((((int)threadIdx.x) >> 3) * 128)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 8192));
    }
    tl::cp_async_commit();
    tl::cp_async_wait<2>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[(k * 2048)])), (&(((half_t*)buf_dyn_shmem)[((k * 2048) + 6144)])), (&(C_local[0])));
  }
  tl::cp_async_wait<1>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[4096])), (&(((half_t*)buf_dyn_shmem)[10240])), (&(C_local[0])));
  tl::cp_async_wait<0>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[6144])), (&(C_local[0])));
  #pragma unroll
  for (int i_7 = 0; i_7 < 16; ++i_7) {
    if ((((((((int)blockIdx.y) * 64) + (((i_7 & 3) >> 1) * 32)) + (((((int)threadIdx.x) & 63) >> 5) * 16)) + ((i_7 & 1) * 8)) + ((((int)threadIdx.x) & 31) >> 2)) < m) {
      *(uint1*)(C + (((((((((((int)blockIdx.y) * 8192) + (((i_7 & 3) >> 1) * 4096)) + (((((int)threadIdx.x) & 63) >> 5) * 2048)) + ((i_7 & 1) * 1024)) + (((((int)threadIdx.x) & 31) >> 2) * 128)) + (((int)blockIdx.x) * 64)) + ((i_7 >> 2) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = *(uint1*)(C_local + (i_7 * 2));
    }
  }
}

@LeiWang1999
Copy link
Contributor Author

Tail split is required when we apply dynamic symbolic into the last dynamic, otherwise the vector load will be blocked

#include <tl_templates/gemm.h>
#include <tl_templates/copy.h>
#include <tl_templates/reduce.h>
#include <tl_templates/ldsm.h>
#include <tl_templates/threadblock_swizzle.h>

extern "C" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C, int k, int m, int n) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  half_t C_local[32];
  #pragma unroll
  for (int i = 0; i < 16; ++i) {
    *(uint1*)(C_local + (i * 2)) = make_uint1(__pack_half2(half_t(0.000000e+00f), half_t(0.000000e+00f)));
  }
  if (0 < k) {
    #pragma unroll
    for (int i_1 = 0; i_1 < 16; ++i_1) {
      half_t condval;
      if ((((((((int)blockIdx.y) * 64) + (i_1 * 4)) + (((int)threadIdx.x) >> 5)) < m) && ((((int)threadIdx.x) & 31) < k))) {
        condval = A[(((((((int)blockIdx.y) * 64) + (i_1 * 4)) + (((int)threadIdx.x) >> 5)) * k) + (((int)threadIdx.x) & 31))];
      } else {
        condval = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[(((((i_1 * 128) + ((((int)threadIdx.x) >> 5) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_1 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7))] = condval;
    }
  }
  if (0 < k) {
    #pragma unroll
    for (int i_2 = 0; i_2 < 16; ++i_2) {
      half_t condval_1;
      if (((((i_2 * 2) + (((int)threadIdx.x) >> 6)) < k) && (((((int)blockIdx.x) * 64) + (((int)threadIdx.x) & 63)) < n))) {
        condval_1 = B[(((((int)blockIdx.x) * 64) + (((i_2 * 2) + (((int)threadIdx.x) >> 6)) * n)) + (((int)threadIdx.x) & 63))];
      } else {
        condval_1 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[(((((((i_2 * 128) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((i_2 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_2 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 6144)] = condval_1;
    }
    tl::cp_async_commit();
  }
  if (32 < k) {
    #pragma unroll
    for (int i_3 = 0; i_3 < 16; ++i_3) {
      half_t condval_2;
      if ((((((((int)blockIdx.y) * 64) + (i_3 * 4)) + (((int)threadIdx.x) >> 5)) < m) && (((((int)threadIdx.x) & 31) + 32) < k))) {
        condval_2 = A[((((((((int)blockIdx.y) * 64) + (i_3 * 4)) + (((int)threadIdx.x) >> 5)) * k) + (((int)threadIdx.x) & 31)) + 32)];
      } else {
        condval_2 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[((((((i_3 * 128) + ((((int)threadIdx.x) >> 5) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_3 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 2048)] = condval_2;
    }
  }
  if (32 < k) {
    #pragma unroll
    for (int i_4 = 0; i_4 < 16; ++i_4) {
      half_t condval_3;
      if ((((((i_4 * 2) + (((int)threadIdx.x) >> 6)) + 32) < k) && (((((int)blockIdx.x) * 64) + (((int)threadIdx.x) & 63)) < n))) {
        condval_3 = B[(((((int)blockIdx.x) * 64) + ((((i_4 * 2) + (((int)threadIdx.x) >> 6)) + 32) * n)) + (((int)threadIdx.x) & 63))];
      } else {
        condval_3 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[(((((((i_4 * 128) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((i_4 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_4 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 8192)] = condval_3;
    }
    tl::cp_async_commit();
  }
  for (int k_1 = 0; k_1 < (((k + 31) >> 5) - 2); ++k_1) {
    __syncthreads();
    #pragma unroll
    for (int i_5 = 0; i_5 < 16; ++i_5) {
      half_t condval_4;
      if ((((((((int)blockIdx.y) * 64) + (i_5 * 4)) + (((int)threadIdx.x) >> 5)) < m) && ((((k_1 * 32) + (((int)threadIdx.x) & 31)) + 64) < k))) {
        condval_4 = A[((((k_1 * 32) + ((((((int)blockIdx.y) * 64) + (i_5 * 4)) + (((int)threadIdx.x) >> 5)) * k)) + (((int)threadIdx.x) & 31)) + 64)];
      } else {
        condval_4 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[((((((((k_1 + 2) % 3) * 2048) + (i_5 * 128)) + ((((int)threadIdx.x) >> 5) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_5 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7))] = condval_4;
    }
    #pragma unroll
    for (int i_6 = 0; i_6 < 16; ++i_6) {
      half_t condval_5;
      if (((((((k_1 * 32) + (i_6 * 2)) + (((int)threadIdx.x) >> 6)) + 64) < k) && (((((int)blockIdx.x) * 64) + (((int)threadIdx.x) & 63)) < n))) {
        condval_5 = B[((((((int)blockIdx.x) * 64) + (n * 64)) + ((((k_1 * 32) + (i_6 * 2)) + (((int)threadIdx.x) >> 6)) * n)) + (((int)threadIdx.x) & 63))];
      } else {
        condval_5 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[((((((((((k_1 + 2) % 3) * 2048) + (i_6 * 128)) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((i_6 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_6 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 6144)] = condval_5;
    }
    tl::cp_async_commit();
    tl::cp_async_wait<2>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[((k_1 % 3) * 2048)])), (&(((half_t*)buf_dyn_shmem)[(((k_1 % 3) * 2048) + 6144)])), (&(C_local[0])));
  }
  if (33 <= k) {
    tl::cp_async_wait<1>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((((k + 31) >> 5) + 1) % 3) * 2048)])), (&(((half_t*)buf_dyn_shmem)[((((((k + 31) >> 5) + 1) % 3) * 2048) + 6144)])), (&(C_local[0])));
  }
  if (1 <= k) {
    tl::cp_async_wait<0>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((((k + 31) >> 5) + 2) % 3) * 2048)])), (&(((half_t*)buf_dyn_shmem)[((((((k + 31) >> 5) + 2) % 3) * 2048) + 6144)])), (&(C_local[0])));
  }
  #pragma unroll
  for (int i_7 = 0; i_7 < 32; ++i_7) {
    if (((((((((int)blockIdx.y) * 64) + (((i_7 & 7) >> 2) * 32)) + (((((int)threadIdx.x) & 63) >> 5) * 16)) + (((i_7 & 3) >> 1) * 8)) + ((((int)threadIdx.x) & 31) >> 2)) < m) && ((((((((int)blockIdx.x) * 64) + ((i_7 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2)) + (i_7 & 1)) < n)) {
      C[((((((((int)blockIdx.x) * 64) + ((i_7 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2)) + ((((((((int)blockIdx.y) * 64) + (((i_7 & 7) >> 2) * 32)) + (((((int)threadIdx.x) & 63) >> 5) * 16)) + (((i_7 & 3) >> 1) * 8)) + ((((int)threadIdx.x) & 31) >> 2)) * n)) + (i_7 & 1))] = C_local[i_7];
    }
  }
}

lead to significant drop in performance.

@LeiWang1999 LeiWang1999 merged commit 8d60542 into microsoft:main Sep 26, 2024
4 checks passed
@tzj-fxz
Copy link
Contributor

tzj-fxz commented Oct 15, 2024

I'm wondering about the meaning of tail split. Is this related to TensorCoreEmitter?

Tail split is required when we apply dynamic symbolic into the last dynamic, otherwise the vector load will be blocked

#include <tl_templates/gemm.h>
#include <tl_templates/copy.h>
#include <tl_templates/reduce.h>
#include <tl_templates/ldsm.h>
#include <tl_templates/threadblock_swizzle.h>

extern "C" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C, int k, int m, int n) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  half_t C_local[32];
  #pragma unroll
  for (int i = 0; i < 16; ++i) {
    *(uint1*)(C_local + (i * 2)) = make_uint1(__pack_half2(half_t(0.000000e+00f), half_t(0.000000e+00f)));
  }
  if (0 < k) {
    #pragma unroll
    for (int i_1 = 0; i_1 < 16; ++i_1) {
      half_t condval;
      if ((((((((int)blockIdx.y) * 64) + (i_1 * 4)) + (((int)threadIdx.x) >> 5)) < m) && ((((int)threadIdx.x) & 31) < k))) {
        condval = A[(((((((int)blockIdx.y) * 64) + (i_1 * 4)) + (((int)threadIdx.x) >> 5)) * k) + (((int)threadIdx.x) & 31))];
      } else {
        condval = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[(((((i_1 * 128) + ((((int)threadIdx.x) >> 5) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_1 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7))] = condval;
    }
  }
  if (0 < k) {
    #pragma unroll
    for (int i_2 = 0; i_2 < 16; ++i_2) {
      half_t condval_1;
      if (((((i_2 * 2) + (((int)threadIdx.x) >> 6)) < k) && (((((int)blockIdx.x) * 64) + (((int)threadIdx.x) & 63)) < n))) {
        condval_1 = B[(((((int)blockIdx.x) * 64) + (((i_2 * 2) + (((int)threadIdx.x) >> 6)) * n)) + (((int)threadIdx.x) & 63))];
      } else {
        condval_1 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[(((((((i_2 * 128) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((i_2 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_2 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 6144)] = condval_1;
    }
    tl::cp_async_commit();
  }
  if (32 < k) {
    #pragma unroll
    for (int i_3 = 0; i_3 < 16; ++i_3) {
      half_t condval_2;
      if ((((((((int)blockIdx.y) * 64) + (i_3 * 4)) + (((int)threadIdx.x) >> 5)) < m) && (((((int)threadIdx.x) & 31) + 32) < k))) {
        condval_2 = A[((((((((int)blockIdx.y) * 64) + (i_3 * 4)) + (((int)threadIdx.x) >> 5)) * k) + (((int)threadIdx.x) & 31)) + 32)];
      } else {
        condval_2 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[((((((i_3 * 128) + ((((int)threadIdx.x) >> 5) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_3 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 2048)] = condval_2;
    }
  }
  if (32 < k) {
    #pragma unroll
    for (int i_4 = 0; i_4 < 16; ++i_4) {
      half_t condval_3;
      if ((((((i_4 * 2) + (((int)threadIdx.x) >> 6)) + 32) < k) && (((((int)blockIdx.x) * 64) + (((int)threadIdx.x) & 63)) < n))) {
        condval_3 = B[(((((int)blockIdx.x) * 64) + ((((i_4 * 2) + (((int)threadIdx.x) >> 6)) + 32) * n)) + (((int)threadIdx.x) & 63))];
      } else {
        condval_3 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[(((((((i_4 * 128) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((i_4 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_4 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 8192)] = condval_3;
    }
    tl::cp_async_commit();
  }
  for (int k_1 = 0; k_1 < (((k + 31) >> 5) - 2); ++k_1) {
    __syncthreads();
    #pragma unroll
    for (int i_5 = 0; i_5 < 16; ++i_5) {
      half_t condval_4;
      if ((((((((int)blockIdx.y) * 64) + (i_5 * 4)) + (((int)threadIdx.x) >> 5)) < m) && ((((k_1 * 32) + (((int)threadIdx.x) & 31)) + 64) < k))) {
        condval_4 = A[((((k_1 * 32) + ((((((int)blockIdx.y) * 64) + (i_5 * 4)) + (((int)threadIdx.x) >> 5)) * k)) + (((int)threadIdx.x) & 31)) + 64)];
      } else {
        condval_4 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[((((((((k_1 + 2) % 3) * 2048) + (i_5 * 128)) + ((((int)threadIdx.x) >> 5) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_5 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7))] = condval_4;
    }
    #pragma unroll
    for (int i_6 = 0; i_6 < 16; ++i_6) {
      half_t condval_5;
      if (((((((k_1 * 32) + (i_6 * 2)) + (((int)threadIdx.x) >> 6)) + 64) < k) && (((((int)blockIdx.x) * 64) + (((int)threadIdx.x) & 63)) < n))) {
        condval_5 = B[((((((int)blockIdx.x) * 64) + (n * 64)) + ((((k_1 * 32) + (i_6 * 2)) + (((int)threadIdx.x) >> 6)) * n)) + (((int)threadIdx.x) & 63))];
      } else {
        condval_5 = half_t(0.000000e+00f);
      }
      ((half_t*)buf_dyn_shmem)[((((((((((k_1 + 2) % 3) * 2048) + (i_6 * 128)) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((i_6 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_6 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 15) >> 3)) & 1) * 8)) + (((int)threadIdx.x) & 7)) + 6144)] = condval_5;
    }
    tl::cp_async_commit();
    tl::cp_async_wait<2>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[((k_1 % 3) * 2048)])), (&(((half_t*)buf_dyn_shmem)[(((k_1 % 3) * 2048) + 6144)])), (&(C_local[0])));
  }
  if (33 <= k) {
    tl::cp_async_wait<1>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((((k + 31) >> 5) + 1) % 3) * 2048)])), (&(((half_t*)buf_dyn_shmem)[((((((k + 31) >> 5) + 1) % 3) * 2048) + 6144)])), (&(C_local[0])));
  }
  if (1 <= k) {
    tl::cp_async_wait<0>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((((k + 31) >> 5) + 2) % 3) * 2048)])), (&(((half_t*)buf_dyn_shmem)[((((((k + 31) >> 5) + 2) % 3) * 2048) + 6144)])), (&(C_local[0])));
  }
  #pragma unroll
  for (int i_7 = 0; i_7 < 32; ++i_7) {
    if (((((((((int)blockIdx.y) * 64) + (((i_7 & 7) >> 2) * 32)) + (((((int)threadIdx.x) & 63) >> 5) * 16)) + (((i_7 & 3) >> 1) * 8)) + ((((int)threadIdx.x) & 31) >> 2)) < m) && ((((((((int)blockIdx.x) * 64) + ((i_7 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2)) + (i_7 & 1)) < n)) {
      C[((((((((int)blockIdx.x) * 64) + ((i_7 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2)) + ((((((((int)blockIdx.y) * 64) + (((i_7 & 7) >> 2) * 32)) + (((((int)threadIdx.x) & 63) >> 5) * 16)) + (((i_7 & 3) >> 1) * 8)) + ((((int)threadIdx.x) & 31) >> 2)) * n)) + (i_7 & 1))] = C_local[i_7];
    }
  }
}

lead to significant drop in performance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants