Skip to content

Commit

Permalink
gemm_ss
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Sep 4, 2024
1 parent 7bb21e7 commit 6a22442
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions bitblas/tl/macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps)
self.threads = threads

def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 256 // DataType(a_dtype).bits
Expand Down Expand Up @@ -91,17 +91,6 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_y = n_dim
self.micro_size_k = k_dim

def _initialize_thread_axis(self,
threads=128,
warp_size=32,
block_row_warps=2,
block_col_warps=2):
self.threads = threads
# thread_bindings = T.env_thread("threadIdx.x")
# self.tx = thread_bindings % warp_size
# self.ty = (thread_bindings // warp_size) % block_row_warps
# self.tz = thread_bindings // (warp_size * block_row_warps)

@staticmethod
@T.macro
def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
Expand Down Expand Up @@ -209,3 +198,32 @@ def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row,
col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
j * inst.local_size_out + local_id]

# Allow GEMM from shared memory to local memory
@staticmethod
@T.macro
def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings):
A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size),
inst.a_dtype,
scope="local")
B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size),
inst.b_dtype,
scope="local")
for ki in T.serial(0, (inst.block_K // inst.micro_size_k)):
inst.LDMATRIX_A(
inst,
A_local_buf,
A_shared_buf,
ki,
thread_bindings=thread_bindings,
)

inst.LDMATRIX_B(
inst,
B_local_buf,
B_shared_buf,
ki,
thread_bindings=thread_bindings,
)

inst.MMA(inst, A_local_buf, B_local_buf, C_local_buf)

0 comments on commit 6a22442

Please sign in to comment.